diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..9d6bb2c6c165eabaf24604396cac0d4ab2a2224f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +Assets/ExampleInput/music.wav filter=lfs diff=lfs merge=lfs -text +Assets/ExampleInput/soundeffects.wav filter=lfs diff=lfs merge=lfs -text +Assets/ExampleInput/speech.wav filter=lfs diff=lfs merge=lfs -text +Assets/Figure.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f5ce32c448f943505687920a52cb9d566c82117f --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +*.pyc +*.pyo +.eggs/ +*.egg-info/ +dist/ +build/ +.env diff --git a/Assets/ExampleInput/music.wav b/Assets/ExampleInput/music.wav new file mode 100644 index 0000000000000000000000000000000000000000..5088ffd7e19bd52d042e59f4eab43826d213e3f5 --- /dev/null +++ b/Assets/ExampleInput/music.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57031bac5cdc14b2be1c50811eca257070bf1fa9fb98d2aebeb7eda86c67ceaa +size 491564 diff --git a/Assets/ExampleInput/soundeffects.wav b/Assets/ExampleInput/soundeffects.wav new file mode 100644 index 0000000000000000000000000000000000000000..5627ac6ec5c098d1f0d7f66af10dcd8247220851 --- /dev/null +++ b/Assets/ExampleInput/soundeffects.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96b31750ddb93397789224ddc047ee21cb34760f536818e9ea4cd2328d8dfe69 +size 491564 diff --git a/Assets/ExampleInput/speech.wav b/Assets/ExampleInput/speech.wav new file mode 100644 index 0000000000000000000000000000000000000000..09523e201abdfdb851f26da8ff2849c90b47e50e --- /dev/null +++ b/Assets/ExampleInput/speech.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:046907c4a7a5bcb9f1b1bb49623fdc4d01be215d6c2783ff5157150938ed21a2 +size 491564 diff --git a/Assets/Figure.png b/Assets/Figure.png new file mode 100644 index 0000000000000000000000000000000000000000..c532d9c7ba9e3d6cdea158dc178541d127b5c203 --- /dev/null +++ b/Assets/Figure.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70639d9a3d6fd61ac73b97104cc082e816a81b4b211ac9092cfee557d8aff6c7 +size 1090177 diff --git a/FlashSR/AudioSR/AudioSRUnet.py b/FlashSR/AudioSR/AudioSRUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1ee04b5f0187223fd16439f4e5dbacbec43027aa --- /dev/null +++ b/FlashSR/AudioSR/AudioSRUnet.py @@ -0,0 +1,1127 @@ +''' +from TorchJaekwon.Util.Util import Util +Util.set_sys_path_to_parent_dir(__file__,3) +import sys, os +sys.path.append(os.path.dirname(os.path.dirname(__file__))) +''' +################################################################################ +#just copy from audiosr/latent_diffusion/modules/diffusionmodules/openaimodel.py +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from FlashSR.AudioSR.latent_diffusion.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from FlashSR.AudioSR.latent_diffusion.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + 
+## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1).contiguous() # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context_list=None, mask_list=None): + # The first spatial transformer block does not have context + spatial_transformer_id = 0 + context_list = [None] + context_list + mask_list = [None] + mask_list + + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + if spatial_transformer_id >= len(context_list): + context, mask = None, None + else: + context, mask = ( + context_list[spatial_transformer_id], + mask_list[spatial_transformer_id], + ) + + x = layer(x, context, mask=mask) + spatial_transformer_id += 1 + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. 
+ :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1).contiguous() + qkv = self.qkv(self.norm(x)).contiguous() + h = self.attention(qkv).contiguous() + h = self.proj_out(h).contiguous() + return (x + h).reshape(b, c, *spatial).contiguous() + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = ( + qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1) + ) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum( + "bts,bcs->bct", + weight, + v.reshape(bs * self.n_heads, ch, length).contiguous(), + ) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class AudioSRUnet(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
+ """ + + def __init__( + self, + image_size:int = 64, + in_channels:int = 32, + model_channels:int = 128, + out_channels:int = 16, + num_res_blocks:int = 2, + attention_resolutions:list = [8, 4, 2], + dropout=0, + channel_mult=[1, 2, 3, 5], + conv_resample=True, + dims=2, + extra_sa_layer=True, + num_classes=None, + extra_film_condition_dim=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=32, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=True, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.extra_film_condition_dim = extra_film_condition_dim + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + # assert not ( + # self.num_classes is not None and self.extra_film_condition_dim is not None + # ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim." + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.use_extra_film_by_concat = self.extra_film_condition_dim is not None + + if self.extra_film_condition_dim is not None: + self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim) + print( + "+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " + % self.extra_film_condition_dim + ) + + if context_dim is not None and not use_spatial_transformer: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 
+ + if context_dim is not None and not isinstance(context_dim, list): + context_dim = [context_dim] + elif context_dim is None: + context_dim = [None] # At least use one spatial transformer + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + middle_layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + if extra_sa_layer: + middle_layers.append( + SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=None + ) + ) + for context_dim_id in range(len(context_dim)): + middle_layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + middle_layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + 
) + self.middle_block = TimestepEmbedSequential(*middle_layers) + + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + self.shape_reported = False + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward( + self, + x, + timesteps=None, + y=None, + context_list=list(), + context_attn_mask_list=list(), + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional + :return: an [N x C x ...] Tensor of outputs. 
+ """ + x = th.concat([x,y], dim=1) #jakeoneijk added + y = None + if not self.shape_reported: + # print("The shape of UNet input is", x.size()) + self.shape_reported = True + + assert (y is not None) == ( + self.num_classes is not None or self.extra_film_condition_dim is not None + ), "must specify y if and only if the model is class-conditional or film embedding conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + # if self.num_classes is not None: + # assert y.shape == (x.shape[0],) + # emb = emb + self.label_emb(y) + + if self.use_extra_film_by_concat: + emb = th.cat([emb, self.film_emb(y)], dim=-1) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context_list, context_attn_mask_list) + hs.append(h) + h = self.middle_block(h, emb, context_list, context_attn_mask_list) + for module in self.output_blocks: + concate_tensor = hs.pop() + h = th.cat([h, concate_tensor], dim=1) + h = module(h, emb, context_list, context_attn_mask_list) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + 
ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. 
+ """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + +if __name__ == '__main__': + ''' + + args = { + 'in_channels': 2, + 'model_channels': 64, + 'out_channels': 1 + } + audio_sr = AudioSRUnet(**args) + audio_sr(x = th.randn(1, 1, 128, 256), timesteps = th.tensor([30]), y = th.randn(1, 1, 128, 256)) #jakeoneijk added +''' + audio_sr = AudioSRUnet() + audio_sr(x = th.randn(1, 16, 64, 32), timesteps = th.tensor([30]), y = th.randn(1, 16, 64, 32)) #jakeoneijk added \ No newline at end of file diff --git a/FlashSR/AudioSR/EncoderDecoder.py b/FlashSR/AudioSR/EncoderDecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..aa54efac86ba56bccd4f43052411910c306cc78d --- /dev/null +++ b/FlashSR/AudioSR/EncoderDecoder.py @@ -0,0 +1,1010 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class UpsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=1, padding=2 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class DownsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, 
out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() 
+ attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + 
in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + downsample_time_stride4_levels=[], + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level in self.downsample_time_stride4_levels: + down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv) + else: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + downsample_time_stride4_levels=[], + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + 
self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # compute in_ch_mult, block_in and curr_res at lowest res + (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + # print( + # "Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape) + # ) + # ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level - 1 in self.downsample_time_stride4_levels: + up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) + else: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + 
temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x).contiguous() + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + 
z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x \ No newline at end of file diff --git a/FlashSR/AudioSR/Vocoder.py b/FlashSR/AudioSR/Vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac265761a9522c60afe3fdbece63b05f0c42484 --- /dev/null +++ b/FlashSR/AudioSR/Vocoder.py @@ -0,0 +1,167 @@ +import torch + +import FlashSR.AudioSR.hifigan as hifigan + + +def get_vocoder_config(): + return { + "resblock": "1", + "num_gpus": 6, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + 
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "segment_size": 8192, + "num_mels": 64, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 160, + "win_size": 1024, + "sampling_rate": 16000, + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, + } + + +def get_vocoder_config_48k(): + return { + "resblock": "1", + "num_gpus": 8, + "batch_size": 128, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [6, 5, 4, 2, 2], + "upsample_kernel_sizes": [12, 10, 8, 4, 4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3, 7, 11, 15], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]], + "segment_size": 15360, + "num_mels": 256, + "n_fft": 2048, + "hop_size": 480, + "win_size": 2048, + "sampling_rate": 48000, + "fmin": 20, + "fmax": 24000, + "fmax_for_loss": None, + "num_workers": 8, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:18273", + "world_size": 1, + }, + } + + +def get_available_checkpoint_keys(model, ckpt): + state_dict = torch.load(ckpt)["state_dict"] + current_state_dict = model.state_dict() + new_state_dict = {} + for k in state_dict.keys(): + if ( + k in current_state_dict.keys() + and current_state_dict[k].size() == state_dict[k].size() + ): + new_state_dict[k] = state_dict[k] + else: + print("==> WARNING: Skipping %s" % k) + print( + "%s out of %s keys are matched" + % (len(new_state_dict.keys()), len(state_dict.keys())) + ) + return new_state_dict + + +def get_param_num(model): + num_param = sum(param.numel() for param in model.parameters()) + return num_param + + +def torch_version_orig_mod_remove(state_dict): + new_state_dict = {} + new_state_dict["generator"] = {} + for key in state_dict["generator"].keys(): + if "_orig_mod." 
in key: + new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[ + "generator" + ][key] + else: + new_state_dict["generator"][key] = state_dict["generator"][key] + return new_state_dict + + +def get_vocoder(config, device, mel_bins): + name = "HiFi-GAN" + speaker = "" + if name == "MelGAN": + if speaker == "LJSpeech": + vocoder = torch.hub.load( + "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" + ) + elif speaker == "universal": + vocoder = torch.hub.load( + "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" + ) + vocoder.mel2wav.eval() + vocoder.mel2wav.to(device) + elif name == "HiFi-GAN": + if mel_bins == 64: + config = get_vocoder_config() + config = hifigan.AttrDict(config) + vocoder = hifigan.Generator_old(config) + # print("Load hifigan/g_01080000") + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) + # ckpt = torch_version_orig_mod_remove(ckpt) + # vocoder.load_state_dict(ckpt["generator"]) + vocoder.eval() + vocoder.remove_weight_norm() + vocoder.to(device) + else: + config = get_vocoder_config_48k() + config = hifigan.AttrDict(config) + vocoder = hifigan.Generator_old(config) + # print("Load hifigan/g_01080000") + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) + # ckpt = torch_version_orig_mod_remove(ckpt) + # vocoder.load_state_dict(ckpt["generator"]) + vocoder.eval() + vocoder.remove_weight_norm() + vocoder.to(device) + return vocoder + + +def vocoder_infer(mels, vocoder, lengths=None): + with torch.no_grad(): + wavs = vocoder(mels).squeeze(1) + + wavs = (wavs.cpu().numpy() * 32768).astype("int16") + + if lengths is not None: + wavs = wavs[:, :lengths] + + # wavs = [wav for wav in wavs] + + # for i in range(len(mels)): + # if lengths is not None: + # wavs[i] = wavs[i][: lengths[i]] + + return wavs diff --git a/FlashSR/AudioSR/args/mel_argument.yaml b/FlashSR/AudioSR/args/mel_argument.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a50510087283d519c6fe59873cbfeedc3ac25b8f --- /dev/null +++ b/FlashSR/AudioSR/args/mel_argument.yaml @@ -0,0 +1,6 @@ +nfft: 2048 +hop_size: 480 +sample_rate: 48000 +mel_size: 256 +frequency_min: 20 +frequency_max: 24000 \ No newline at end of file diff --git a/FlashSR/AudioSR/args/model_argument.yaml b/FlashSR/AudioSR/args/model_argument.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5a83d6780d29d3017995eeabfaa5973ae0e9bdae --- /dev/null +++ b/FlashSR/AudioSR/args/model_argument.yaml @@ -0,0 +1,25 @@ +batchsize: 4 +ddconfig: + attn_resolutions: [] + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 8 + double_z: true + downsample_time: false + dropout: 0.1 + in_channels: 1 + mel_bins: 256 + num_res_blocks: 2 + out_ch: 1 + resolution: 256 + z_channels: 16 +embed_dim: 16 +image_key: fbank +monitor: val/rec_loss +reload_from_ckpt: /mnt/bn/lqhaoheliu/project/audio_generation_diffusion/log/vae/vae_48k_256/ds_8_kl_1/checkpoints/ckpt-checkpoint-484999.ckpt +sampling_rate: 48000 +subband: 1 +time_shuffle: 1 diff --git a/FlashSR/AudioSR/autoencoder.py b/FlashSR/AudioSR/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0a9c502fb76f1818427736ed62d3287de3beb6d6 --- /dev/null +++ b/FlashSR/AudioSR/autoencoder.py @@ -0,0 +1,370 @@ +import os +import soundfile as sf +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from FlashSR.AudioSR.EncoderDecoder 
import Encoder, Decoder +from FlashSR.AudioSR.Vocoder import get_vocoder + + +class AutoencoderKL(nn.Module): + def __init__( + self, + ddconfig=None, + lossconfig=None, + batchsize=None, + embed_dim=None, + time_shuffle=1, + subband=1, + sampling_rate=16000, + ckpt_path=None, + reload_from_ckpt=None, + ignore_keys=[], + image_key="fbank", + colorize_nlabels=None, + monitor=None, + base_learning_rate=1e-5, + ): + super().__init__() + self.automatic_optimization = False + assert ( + "mel_bins" in ddconfig.keys() + ), "mel_bins is not specified in the Autoencoder config" + num_mel = ddconfig["mel_bins"] + self.image_key = image_key + self.sampling_rate = sampling_rate + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + + self.loss = None + self.subband = int(subband) + + if self.subband > 1: + print("Use subband decomposition %s" % self.subband) + + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + if self.image_key == "fbank": + self.vocoder = get_vocoder(None, "cpu", num_mel) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.learning_rate = float(base_learning_rate) + # print("Initial learning rate %s" % self.learning_rate) + + self.time_shuffle = time_shuffle + self.reload_from_ckpt = reload_from_ckpt + self.reloaded = False + self.mean, self.std = None, None + + self.feature_cache = None + self.flag_first_run = True + self.train_step = 0 + + self.logger_save_dir = None + self.logger_exp_name = None + + def get_log_dir(self): + if self.logger_save_dir is None and self.logger_exp_name is None: + return os.path.join(self.logger.save_dir, self.logger._project) + else: + return os.path.join(self.logger_save_dir, self.logger_exp_name) + + def set_log_dir(self, save_dir, exp_name): + self.logger_save_dir = save_dir + self.logger_exp_name = exp_name + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + # x = self.time_shuffle_operation(x) + # x = self.freq_split_subband(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + # bs, ch, shuffled_timesteps, fbins = dec.size() + # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins) + # dec = self.freq_merge_subband(dec) + return dec + + def decode_to_waveform(self, dec): + from audiosr.utilities.model import vocoder_infer + + if self.image_key == "fbank": + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = vocoder_infer(dec, self.vocoder) + elif self.image_key == "stft": + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = self.wave_decoder(dec) + return wav_reconstruction + + def visualize_latent(self, input): + import matplotlib.pyplot as plt + + # for i in range(10): + # zero_input = torch.zeros_like(input) 
- 11.59 + # zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59 + + # posterior = self.encode(zero_input) + # latent = posterior.sample() + # avg_latent = torch.mean(latent, dim=1)[0] + # plt.imshow(avg_latent.cpu().detach().numpy().T) + # plt.savefig("%s.png" % i) + # plt.close() + + np.save("input.npy", input.cpu().detach().numpy()) + # zero_input = torch.zeros_like(input) - 11.59 + time_input = input.clone() + time_input[:, :, :, :32] *= 0 + time_input[:, :, :, :32] -= 11.59 + + np.save("time_input.npy", time_input.cpu().detach().numpy()) + + posterior = self.encode(time_input) + latent = posterior.sample() + np.save("time_latent.npy", latent.cpu().detach().numpy()) + avg_latent = torch.mean(latent, dim=1) + for i in range(avg_latent.size(0)): + plt.imshow(avg_latent[i].cpu().detach().numpy().T) + plt.savefig("freq_%s.png" % i) + plt.close() + + freq_input = input.clone() + freq_input[:, :, :512, :] *= 0 + freq_input[:, :, :512, :] -= 11.59 + + np.save("freq_input.npy", freq_input.cpu().detach().numpy()) + + posterior = self.encode(freq_input) + latent = posterior.sample() + np.save("freq_latent.npy", latent.cpu().detach().numpy()) + avg_latent = torch.mean(latent, dim=1) + for i in range(avg_latent.size(0)): + plt.imshow(avg_latent[i].cpu().detach().numpy().T) + plt.savefig("time_%s.png" % i) + plt.close() + + def get_input(self, batch): + fname, text, label_indices, waveform, stft, fbank = ( + batch["fname"], + batch["text"], + batch["label_vector"], + batch["waveform"], + batch["stft"], + batch["log_mel_spec"], + ) + # if(self.time_shuffle != 1): + # if(fbank.size(1) % self.time_shuffle != 0): + # pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle) + # fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len)) + + ret = {} + + ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = ( + fbank.unsqueeze(1), + stft.unsqueeze(1), + fname, + waveform.unsqueeze(1), + ) + + return ret + + def save_wave(self, batch_wav, fname, save_dir): + os.makedirs(save_dir, exist_ok=True) + + for wav, name in zip(batch_wav, fname): + name = os.path.basename(name) + + sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate) + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs): + log = dict() + x = batch.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + log["samples"] = self.decode(posterior.sample()) + log["reconstructions"] = xrec + + log["inputs"] = x + wavs = self._log_img(log, train=train, index=0, waveform=waveform) + return wavs + + def _log_img(self, log, train=True, index=0, waveform=None): + images_input = self.tensor2numpy(log["inputs"][index, 0]).T + images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T + images_samples = self.tensor2numpy(log["samples"][index, 0]).T + + if train: + name = "train" + else: + name = "val" + + if self.logger is not None: + self.logger.log_image( + "img_%s" % name, + [images_input, images_reconstruct, images_samples], + caption=["input", "reconstruct", "samples"], + ) + + inputs, reconstructions, samples = ( + log["inputs"], + log["reconstructions"], + log["samples"], + ) + + if self.image_key == "fbank": + wav_original, wav_prediction = synth_one_sample( + inputs[index], + reconstructions[index], + labels="validation", + vocoder=self.vocoder, + ) + wav_original, wav_samples = synth_one_sample( + inputs[index], samples[index], labels="validation", vocoder=self.vocoder + ) + 
wav_original, wav_samples, wav_prediction = ( + wav_original[0], + wav_samples[0], + wav_prediction[0], + ) + elif self.image_key == "stft": + wav_prediction = ( + self.decode_to_waveform(reconstructions)[index, 0] + .cpu() + .detach() + .numpy() + ) + wav_samples = ( + self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy() + ) + wav_original = waveform[index, 0].cpu().detach().numpy() + + if self.logger is not None: + self.logger.experiment.log( + { + "original_%s" + % name: wandb.Audio( + wav_original, caption="original", sample_rate=self.sampling_rate + ), + "reconstruct_%s" + % name: wandb.Audio( + wav_prediction, + caption="reconstruct", + sample_rate=self.sampling_rate, + ), + "samples_%s" + % name: wandb.Audio( + wav_samples, caption="samples", sample_rate=self.sampling_rate + ), + } + ) + + return wav_original, wav_prediction, wav_samples + + def tensor2numpy(self, tensor): + return tensor.cpu().detach().numpy() + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean \ No newline at end of file diff --git a/FlashSR/AudioSR/hifigan/LICENSE b/FlashSR/AudioSR/hifigan/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..5afae394d6b37da0e12ba6b290d2512687f421ac --- /dev/null +++ b/FlashSR/AudioSR/hifigan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, 
including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/FlashSR/AudioSR/hifigan/__init__.py b/FlashSR/AudioSR/hifigan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34e055557bf2ecb457376663b67390543c71fb1f --- /dev/null +++ b/FlashSR/AudioSR/hifigan/__init__.py @@ -0,0 +1,8 @@ +from .models_v2 import Generator +from .models import Generator as Generator_old + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self diff --git a/FlashSR/AudioSR/hifigan/models.py b/FlashSR/AudioSR/hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..c4382cc39de0463f9b7c0f33f037dbc233e7cb36 --- /dev/null +++ b/FlashSR/AudioSR/hifigan/models.py @@ -0,0 +1,174 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + 
remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/FlashSR/AudioSR/hifigan/models_v2.py b/FlashSR/AudioSR/hifigan/models_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..27a2df6b54bdd3a5b259645442624800ac0e8afe --- /dev/null +++ b/FlashSR/AudioSR/hifigan/models_v2.py @@ -0,0 +1,395 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + 
] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + u * 2, + u, + padding=u // 2 + u % 2, + output_padding=u % 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + # import ipdb; ipdb.set_trace() + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +################################################################################################## + +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# from torch.nn import Conv1d, ConvTranspose1d +# from torch.nn.utils import weight_norm, remove_weight_norm + +# LRELU_SLOPE = 0.1 + + +# def init_weights(m, mean=0.0, std=0.01): +# classname = m.__class__.__name__ +# if classname.find("Conv") != -1: +# m.weight.data.normal_(mean, std) + + +# def get_padding(kernel_size, dilation=1): +# return int((kernel_size * dilation - dilation) / 2) + + +# class ResBlock(torch.nn.Module): +# def __init__(self, h, channels, kernel_size=3, 
dilation=(1, 3, 5)): +# super(ResBlock, self).__init__() +# self.h = h +# self.convs1 = nn.ModuleList( +# [ +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=dilation[0], +# padding=get_padding(kernel_size, dilation[0]), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=dilation[1], +# padding=get_padding(kernel_size, dilation[1]), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=dilation[2], +# padding=get_padding(kernel_size, dilation[2]), +# ) +# ), +# ] +# ) +# self.convs1.apply(init_weights) + +# self.convs2 = nn.ModuleList( +# [ +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=1, +# padding=get_padding(kernel_size, 1), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=1, +# padding=get_padding(kernel_size, 1), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=1, +# padding=get_padding(kernel_size, 1), +# ) +# ), +# ] +# ) +# self.convs2.apply(init_weights) + +# def forward(self, x): +# for c1, c2 in zip(self.convs1, self.convs2): +# xt = F.leaky_relu(x, LRELU_SLOPE) +# xt = c1(xt) +# xt = F.leaky_relu(xt, LRELU_SLOPE) +# xt = c2(xt) +# x = xt + x +# return x + +# def remove_weight_norm(self): +# for l in self.convs1: +# remove_weight_norm(l) +# for l in self.convs2: +# remove_weight_norm(l) + +# class Generator(torch.nn.Module): +# def __init__(self, h): +# super(Generator, self).__init__() +# self.h = h +# self.num_kernels = len(h.resblock_kernel_sizes) +# self.num_upsamples = len(h.upsample_rates) +# self.conv_pre = weight_norm( +# Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) +# ) +# resblock = ResBlock + +# self.ups = nn.ModuleList() +# for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): +# self.ups.append( +# weight_norm( +# ConvTranspose1d( +# h.upsample_initial_channel // (2**i), +# h.upsample_initial_channel // (2 ** (i + 1)), +# k, +# u, +# padding=(k - u) // 2, +# ) +# ) +# ) + +# self.resblocks = nn.ModuleList() +# for i in range(len(self.ups)): +# ch = h.upsample_initial_channel // (2 ** (i + 1)) +# for j, (k, d) in enumerate( +# zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) +# ): +# self.resblocks.append(resblock(h, ch, k, d)) + +# self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) +# self.ups.apply(init_weights) +# self.conv_post.apply(init_weights) + +# def forward(self, x): +# x = self.conv_pre(x) +# for i in range(self.num_upsamples): +# x = F.leaky_relu(x, LRELU_SLOPE) +# x = self.ups[i](x) +# xs = None +# for j in range(self.num_kernels): +# if xs is None: +# xs = self.resblocks[i * self.num_kernels + j](x) +# else: +# xs += self.resblocks[i * self.num_kernels + j](x) +# x = xs / self.num_kernels +# x = F.leaky_relu(x) +# x = self.conv_post(x) +# x = torch.tanh(x) + +# return x + +# def remove_weight_norm(self): +# print("Removing weight norm...") +# for l in self.ups: +# remove_weight_norm(l) +# for l in self.resblocks: +# l.remove_weight_norm() +# remove_weight_norm(self.conv_pre) +# remove_weight_norm(self.conv_post) diff --git a/FlashSR/AudioSR/latent_diffusion/__init__.py b/FlashSR/AudioSR/latent_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/FlashSR/AudioSR/latent_diffusion/modules/attention.py 
b/FlashSR/AudioSR/latent_diffusion/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1132144bd7699b8c91800f84cc05d5e061ca3b0e --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/attention.py @@ -0,0 +1,467 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from FlashSR.AudioSR.latent_diffusion.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b 
c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +# class CrossAttention(nn.Module): +# """ +# ### Cross Attention Layer +# This falls-back to self-attention when conditional embeddings are not specified. +# """ + +# use_flash_attention: bool = True + +# # use_flash_attention: bool = False +# def __init__( +# self, +# query_dim, +# context_dim=None, +# heads=8, +# dim_head=64, +# dropout=0.0, +# is_inplace: bool = True, +# ): +# # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True): +# """ +# :param d_model: is the input embedding size +# :param n_heads: is the number of attention heads +# :param d_head: is the size of a attention head +# :param d_cond: is the size of the conditional embeddings +# :param is_inplace: specifies whether to perform the attention softmax computation inplace to +# save memory +# """ +# super().__init__() + +# self.is_inplace = is_inplace +# self.n_heads = heads +# self.d_head = dim_head + +# # Attention scaling factor +# self.scale = dim_head**-0.5 + +# # The normal self-attention layer +# if context_dim is None: +# context_dim = query_dim + +# # Query, key and value mappings +# d_attn = dim_head * heads +# self.to_q = nn.Linear(query_dim, d_attn, bias=False) +# self.to_k = nn.Linear(context_dim, d_attn, bias=False) +# self.to_v = nn.Linear(context_dim, d_attn, bias=False) + +# # Final linear layer +# self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout)) + +# # Setup [flash attention](https://github.com/HazyResearch/flash-attention). +# # Flash attention is only used if it's installed +# # and `CrossAttention.use_flash_attention` is set to `True`. +# try: +# # You can install flash attention by cloning their Github repo, +# # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention) +# # and then running `python setup.py install` +# from flash_attn.flash_attention import FlashAttention + +# self.flash = FlashAttention() +# # Set the scale for scaled dot-product attention. 
+# self.flash.softmax_scale = self.scale +# # Set to `None` if it's not installed +# except ImportError: +# self.flash = None + +# def forward(self, x, context=None, mask=None): +# """ +# :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` +# :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` +# """ + +# # If `cond` is `None` we perform self attention +# has_cond = context is not None +# if not has_cond: +# context = x + +# # Get query, key and value vectors +# q = self.to_q(x) +# k = self.to_k(context) +# v = self.to_v(context) + +# # Use flash attention if it's available and the head size is less than or equal to `128` +# if ( +# CrossAttention.use_flash_attention +# and self.flash is not None +# and not has_cond +# and self.d_head <= 128 +# ): +# return self.flash_attention(q, k, v) +# # Otherwise, fallback to normal attention +# else: +# return self.normal_attention(q, k, v) + +# def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): +# """ +# #### Flash Attention +# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# """ + +# # Get batch size and number of elements along sequence axis (`width * height`) +# batch_size, seq_len, _ = q.shape + +# # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of +# # shape `[batch_size, seq_len, 3, n_heads * d_head]` +# qkv = torch.stack((q, k, v), dim=2) +# # Split the heads +# qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) + +# # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to +# # fit this size. 
+# if self.d_head <= 32: +# pad = 32 - self.d_head +# elif self.d_head <= 64: +# pad = 64 - self.d_head +# elif self.d_head <= 128: +# pad = 128 - self.d_head +# else: +# raise ValueError(f"Head size ${self.d_head} too large for Flash Attention") + +# # Pad the heads +# if pad: +# qkv = torch.cat( +# (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 +# ) + +# # Compute attention +# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ +# # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]` +# # TODO here I add the dtype changing +# out, _ = self.flash(qkv.type(torch.float16)) +# # Truncate the extra head size +# out = out[:, :, :, : self.d_head].float() +# # Reshape to `[batch_size, seq_len, n_heads * d_head]` +# out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) + +# # Map to `[batch_size, height * width, d_model]` with a linear layer +# return self.to_out(out) + +# def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): +# """ +# #### Normal Attention + +# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# """ + +# # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]` +# q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32] +# k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32] +# v = v.view(*v.shape[:2], self.n_heads, -1) + +# # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$ +# attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale + +# # Compute softmax +# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$ +# if self.is_inplace: +# half = attn.shape[0] // 2 +# attn[half:] = attn[half:].softmax(dim=-1) +# attn[:half] = attn[:half].softmax(dim=-1) +# else: +# attn = attn.softmax(dim=-1) + +# # Compute attention output +# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ +# # attn: [bs, 20, 64, 1] +# # v: [bs, 1, 20, 32] +# out = torch.einsum("bhij,bjhd->bihd", attn, v) +# # Reshape to `[batch_size, height * width, n_heads * d_head]` +# out = out.reshape(*out.shape[:2], -1) +# # Map to `[batch_size, height * width, d_model]` with a linear layer +# return self.to_out(out) + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, "b ... 
-> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~(mask == 1), max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None, mask=None): + if context is None: + return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) + else: + return checkpoint( + self._forward, (x, context, mask), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None, mask=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context, mask=mask) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + ): + super().__init__() + + context_dim = context_dim + + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, x, context=None, mask=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + for block in self.transformer_blocks: + x = block(x, context=context, mask=mask) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/AudioMAE.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/AudioMAE.py new file mode 100644 index 0000000000000000000000000000000000000000..855f3782a35542083013df256815564df8d34655 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/AudioMAE.py @@ -0,0 +1,149 @@ +""" +Reference Repo: https://github.com/facebookresearch/AudioMAE +""" + +import torch +import torch.nn as nn +#from timm.models.layers import to_2tuple +import audiosr.latent_diffusion.modules.audiomae.models_vit as models_vit +import audiosr.latent_diffusion.modules.audiomae.models_mae as models_mae + +# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128)) + + +class 
PatchEmbed_new(nn.Module): + """Flexible Image to Patch Embedding""" + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + + self.img_size = img_size + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride + ) # with overlapped patches + # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) + # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w + self.patch_hw = (h, w) + self.num_patches = h * w + + def get_output_shape(self, img_size): + # todo: don't be lazy.. + return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class AudioMAE(nn.Module): + """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)""" + + def __init__( + self, + ): + super().__init__() + model = models_vit.__dict__["vit_base_patch16"]( + num_classes=527, + drop_path_rate=0.1, + global_pool=True, + mask_2d=True, + use_custom_patch=False, + ) + + img_size = (1024, 128) + emb_dim = 768 + + model.patch_embed = PatchEmbed_new( + img_size=img_size, + patch_size=(16, 16), + in_chans=1, + embed_dim=emb_dim, + stride=16, + ) + num_patches = model.patch_embed.num_patches + # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8 + model.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False + ) # fixed sin-cos embedding + + # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth' + # checkpoint = torch.load(checkpoint_path, map_location='cpu') + # msg = model.load_state_dict(checkpoint['model'], strict=False) + # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') + + self.model = model + + def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0): + """ + x: mel fbank [Batch, 1, T, F] + mask_t_prob: 'T masking ratio (percentage of removed patches).' + mask_f_prob: 'F masking ratio (percentage of removed patches).' + """ + return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob) + + +class Vanilla_AudioMAE(nn.Module): + """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)""" + + def __init__( + self, + ): + super().__init__() + model = models_mae.__dict__["mae_vit_base_patch16"]( + in_chans=1, audio_exp=True, img_size=(1024, 128) + ) + + # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth' + # checkpoint = torch.load(checkpoint_path, map_location='cpu') + # msg = model.load_state_dict(checkpoint['model'], strict=False) + + # Skip the missing keys of decoder modules (not required) + # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') + + self.model = model.eval() + + def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False): + """ + x: mel fbank [Batch, 1, 1024 (T), 128 (F)] + mask_ratio: 'masking ratio (percentage of removed patches).' 
+ """ + with torch.no_grad(): + # embed: [B, 513, 768] for mask_ratio=0.0 + if no_mask: + if no_average: + raise RuntimeError("This function is deprecated") + embed = self.model.forward_encoder_no_random_mask_no_average( + x + ) # mask_ratio + else: + embed = self.model.forward_encoder_no_mask(x) # mask_ratio + else: + raise RuntimeError("This function is deprecated") + embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio) + return embed + + +if __name__ == "__main__": + model = Vanilla_AudioMAE().cuda() + input = torch.randn(4, 1, 1024, 128).cuda() + print("The first run") + embed = model(input, mask_ratio=0.0, no_mask=True) + print(embed) + print("The second run") + embed = model(input, mask_ratio=0.0) + print(embed) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_mae.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_mae.py new file mode 100644 index 0000000000000000000000000000000000000000..84395f2ddfd31ad81a420121dd24ca73b10f59e9 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_mae.py @@ -0,0 +1,613 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn + +from timm.models.vision_transformer import Block +from audiosr.latent_diffusion.modules.audiomae.util.pos_embed import ( + get_2d_sincos_pos_embed, + get_2d_sincos_pos_embed_flexible, +) +from audiosr.latent_diffusion.modules.audiomae.util.patch_embed import ( + PatchEmbed_new, + PatchEmbed_org, +) + + +class MaskedAutoencoderViT(nn.Module): + """Masked Autoencoder with VisionTransformer backbone""" + + def __init__( + self, + img_size=224, + patch_size=16, + stride=10, + in_chans=3, + embed_dim=1024, + depth=24, + num_heads=16, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4.0, + norm_layer=nn.LayerNorm, + norm_pix_loss=False, + audio_exp=False, + alpha=0.0, + temperature=0.2, + mode=0, + contextual_depth=8, + use_custom_patch=False, + split_pos=False, + pos_trainable=False, + use_nce=False, + beta=4.0, + decoder_mode=0, + mask_t_prob=0.6, + mask_f_prob=0.5, + mask_2d=False, + epoch=0, + no_shift=False, + ): + super().__init__() + + self.audio_exp = audio_exp + self.embed_dim = embed_dim + self.decoder_embed_dim = decoder_embed_dim + # -------------------------------------------------------------------------- + # MAE encoder specifics + if use_custom_patch: + print( + f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}" + ) + self.patch_embed = PatchEmbed_new( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + stride=stride, + ) + else: + self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim) + self.use_custom_patch = use_custom_patch + num_patches = self.patch_embed.num_patches + + self.cls_token = 
nn.Parameter(torch.zeros(1, 1, embed_dim)) + + # self.split_pos = split_pos # not useful + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable + ) # fixed sin-cos embedding + + self.encoder_depth = depth + self.contextual_depth = contextual_depth + self.blocks = nn.ModuleList( + [ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) # qk_scale=None + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + + # -------------------------------------------------------------------------- + # MAE decoder specifics + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, decoder_embed_dim), + requires_grad=pos_trainable, + ) # fixed sin-cos embedding + + self.no_shift = no_shift + + self.decoder_mode = decoder_mode + if ( + self.use_custom_patch + ): # overlapped patches as in AST. Similar performance yet compute heavy + window_size = (6, 6) + feat_size = (102, 12) + else: + window_size = (4, 4) + feat_size = (64, 8) + if self.decoder_mode == 1: + decoder_modules = [] + for index in range(16): + if self.no_shift: + shift_size = (0, 0) + else: + if (index % 2) == 0: + shift_size = (0, 0) + else: + shift_size = (2, 0) + # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]) + decoder_modules.append( + SwinTransformerBlock( + dim=decoder_embed_dim, + num_heads=16, + feat_size=feat_size, + window_size=window_size, + shift_size=shift_size, + mlp_ratio=mlp_ratio, + drop=0.0, + drop_attn=0.0, + drop_path=0.0, + extra_norm=False, + sequential_attn=False, + norm_layer=norm_layer, # nn.LayerNorm, + ) + ) + self.decoder_blocks = nn.ModuleList(decoder_modules) + else: + # Transfomer + self.decoder_blocks = nn.ModuleList( + [ + Block( + decoder_embed_dim, + decoder_num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) # qk_scale=None, + for i in range(decoder_depth) + ] + ) + + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear( + decoder_embed_dim, patch_size**2 * in_chans, bias=True + ) # decoder to patch + + # -------------------------------------------------------------------------- + + self.norm_pix_loss = norm_pix_loss + + self.patch_size = patch_size + self.stride = stride + + # audio exps + self.alpha = alpha + self.T = temperature + self.mode = mode + self.use_nce = use_nce + self.beta = beta + + self.log_softmax = nn.LogSoftmax(dim=-1) + + self.mask_t_prob = mask_t_prob + self.mask_f_prob = mask_f_prob + self.mask_2d = mask_2d + + self.epoch = epoch + + self.initialize_weights() + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + if self.audio_exp: + pos_embed = get_2d_sincos_pos_embed_flexible( + self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True + ) + else: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.patch_embed.num_patches**0.5), + cls_token=True, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + if self.audio_exp: + decoder_pos_embed = get_2d_sincos_pos_embed_flexible( + self.decoder_pos_embed.shape[-1], + self.patch_embed.patch_hw, + cls_token=True, + ) + else: + decoder_pos_embed = get_2d_sincos_pos_embed( + self.decoder_pos_embed.shape[-1], + int(self.patch_embed.num_patches**0.5), + cls_token=True, + ) + 
self.decoder_pos_embed.data.copy_( + torch.from_numpy(decoder_pos_embed).float().unsqueeze(0) + ) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=0.02) + torch.nn.init.normal_(self.mask_token, std=0.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + L = (H/p)*(W/p) + """ + p = self.patch_embed.patch_size[0] + # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + if self.audio_exp: + if self.use_custom_patch: # overlapped patch + h, w = self.patch_embed.patch_hw + # todo: fixed h/w patch size and stride size. Make hw custom in the future + x = imgs.unfold(2, self.patch_size, self.stride).unfold( + 3, self.patch_size, self.stride + ) # n,1,H,W -> n,1,h,w,p,p + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) + # x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p)) + # x = torch.einsum('nchpwq->nhwpqc', x) + # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) + else: + h = imgs.shape[2] // p + w = imgs.shape[3] // p + # h,w = self.patch_embed.patch_hw + x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) + else: + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + specs: (N, 1, H, W) + """ + p = self.patch_embed.patch_size[0] + h = 1024 // p + w = 128 // p + x = x.reshape(shape=(x.shape[0], h, w, p, p, 1)) + x = torch.einsum("nhwpqc->nchpwq", x) + specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p)) + return specs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def random_masking_2d(self, x, mask_t_prob, mask_f_prob): + """ + 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob) + Perform per-sample random masking by per-sample shuffling. 
+ Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + if self.use_custom_patch: # overlapped patch + T = 101 + F = 12 + else: + T = 64 + F = 8 + # x = x.reshape(N, T, F, D) + len_keep_t = int(T * (1 - mask_t_prob)) + len_keep_f = int(F * (1 - mask_f_prob)) + + # noise for mask in time + noise_t = torch.rand(N, T, device=x.device) # noise in [0, 1] + # sort noise for each sample aling time + ids_shuffle_t = torch.argsort( + noise_t, dim=1 + ) # ascend: small is keep, large is remove + ids_restore_t = torch.argsort(ids_shuffle_t, dim=1) + ids_keep_t = ids_shuffle_t[:, :len_keep_t] + # noise mask in freq + noise_f = torch.rand(N, F, device=x.device) # noise in [0, 1] + ids_shuffle_f = torch.argsort( + noise_f, dim=1 + ) # ascend: small is keep, large is remove + ids_restore_f = torch.argsort(ids_shuffle_f, dim=1) + ids_keep_f = ids_shuffle_f[:, :len_keep_f] # + + # generate the binary mask: 0 is keep, 1 is remove + # mask in freq + mask_f = torch.ones(N, F, device=x.device) + mask_f[:, :len_keep_f] = 0 + mask_f = ( + torch.gather(mask_f, dim=1, index=ids_restore_f) + .unsqueeze(1) + .repeat(1, T, 1) + ) # N,T,F + # mask in time + mask_t = torch.ones(N, T, device=x.device) + mask_t[:, :len_keep_t] = 0 + mask_t = ( + torch.gather(mask_t, dim=1, index=ids_restore_t) + .unsqueeze(1) + .repeat(1, F, 1) + .permute(0, 2, 1) + ) # N,T,F + mask = 1 - (1 - mask_t) * (1 - mask_f) # N, T, F + + # get masked x + id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device) + id2res = id2res + 999 * mask # add a large value for masked elements + id2res2 = torch.argsort(id2res.flatten(start_dim=1)) + ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + ids_restore = torch.argsort(id2res2.flatten(start_dim=1)) + mask = mask.flatten(start_dim=1) + + return x_masked, mask, ids_restore + + def forward_encoder(self, x, mask_ratio, mask_2d=False): + # embed patches + x = self.patch_embed(x) + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + if mask_2d: + x, mask, ids_restore = self.random_masking_2d( + x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob + ) + else: + x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x, mask, ids_restore, None + + def forward_encoder_no_random_mask_no_average(self, x): + # embed patches + x = self.patch_embed(x) + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + # if mask_2d: + # x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob) + # else: + # x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x + + def forward_encoder_no_mask(self, x): + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # 
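# Minimal standalone sketch of the argsort-noise masking used by random_masking /
# forward_encoder above; dummy shapes only, independent of the class:
import torch

def random_masking_demo(x, mask_ratio=0.8):
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(N, L)                                   # one score per token
    ids_shuffle = torch.argsort(noise, dim=1)                  # small score -> kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))
    mask = torch.ones(N, L)
    mask[:, :len_keep] = 0                                     # 0 = keep, 1 = removed
    mask = torch.gather(mask, 1, ids_restore)                  # back to original token order
    return x_masked, mask, ids_restore

x_masked, mask, _ = random_masking_demo(torch.randn(2, 512, 768))
print(x_masked.shape, int(mask[0].sum()))                      # torch.Size([2, 102, 768]) 410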
masking: length -> length * mask_ratio + # x, mask, ids_restore = self.random_masking(x, mask_ratio) + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + contextual_embs = [] + for n, blk in enumerate(self.blocks): + x = blk(x) + if n > self.contextual_depth: + contextual_embs.append(self.norm(x)) + # x = self.norm(x) + contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0) + + return contextual_emb + + def forward_decoder(self, x, ids_restore): + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1 + ) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather( + x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) + ) # unshuffle + x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token + + # add pos embed + x = x + self.decoder_pos_embed + + if self.decoder_mode != 0: + B, L, D = x.shape + x = x[:, 1:, :] + if self.use_custom_patch: + x = x.reshape(B, 101, 12, D) + x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1) # hack + x = x.reshape(B, 1224, D) + if self.decoder_mode > 3: # mvit + x = self.decoder_blocks(x) + else: + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + pred = self.decoder_pred(x) + + # remove cls token + if self.decoder_mode != 0: + if self.use_custom_patch: + pred = pred.reshape(B, 102, 12, 256) + pred = pred[:, :101, :, :] + pred = pred.reshape(B, 1212, 256) + else: + pred = pred + else: + pred = pred[:, 1:, :] + return pred, None, None # emb, emb_pixel + + def forward_loss(self, imgs, pred, mask, norm_pix_loss=False): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + target = self.patchify(imgs) + if norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + + def forward(self, imgs, mask_ratio=0.8): + emb_enc, mask, ids_restore, _ = self.forward_encoder( + imgs, mask_ratio, mask_2d=self.mask_2d + ) + pred, _, _ = self.forward_decoder(emb_enc, ids_restore) # [N, L, p*p*3] + loss_recon = self.forward_loss( + imgs, pred, mask, norm_pix_loss=self.norm_pix_loss + ) + loss_contrastive = torch.FloatTensor([0.0]).cuda() + return loss_recon, pred, mask, loss_contrastive + + +def mae_vit_small_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def mae_vit_base_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def mae_vit_large_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + 
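# Sketch of the loss computed by forward_loss above: per-patch-normalized MSE averaged
# only over masked patches. Dummy tensors; 256 = patch_size**2 * in_chans for 16x16
# mono-spectrogram patches.
import torch

target = torch.randn(2, 512, 256)            # patchified input
pred = torch.randn(2, 512, 256)              # decoder prediction
mask = (torch.rand(2, 512) > 0.2).float()    # 1 = masked (reconstructed), 0 = visible

mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1e-6) ** 0.5   # the norm_pix_loss branch

loss = ((pred - target) ** 2).mean(dim=-1)       # [N, L], mean loss per patch
loss = (loss * mask).sum() / mask.sum()          # mean loss on removed patches only
print(loss.item())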
norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def mae_vit_huge_patch14_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +# set recommended archs +mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_vit.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..cb37adbc16cfb9a232493c473c9400f199655b6c --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_vit.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn +import timm.models.vision_transformer + + +class VisionTransformer(timm.models.vision_transformer.VisionTransformer): + """Vision Transformer with support for global average pooling""" + + def __init__( + self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs + ): + super(VisionTransformer, self).__init__(**kwargs) + + self.global_pool = global_pool + if self.global_pool: + norm_layer = kwargs["norm_layer"] + embed_dim = kwargs["embed_dim"] + self.fc_norm = norm_layer(embed_dim) + del self.norm # remove the original norm + self.mask_2d = mask_2d + self.use_custom_patch = use_custom_patch + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed[:, 1:, :] + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + if self.global_pool: + x = x[:, 1:, :].mean(dim=1) # global pool without cls token + outcome = self.fc_norm(x) + else: + x = self.norm(x) + outcome = x[:, 0] + + return outcome + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. 
+ x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def random_masking_2d(self, x, mask_t_prob, mask_f_prob): + """ + 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob) + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + + N, L, D = x.shape # batch, length, dim + if self.use_custom_patch: + # # for AS + T = 101 # 64,101 + F = 12 # 8,12 + # # for ESC + # T=50 + # F=12 + # for SPC + # T=12 + # F=12 + else: + # ## for AS + T = 64 + F = 8 + # ## for ESC + # T=32 + # F=8 + ## for SPC + # T=8 + # F=8 + + # mask T + x = x.reshape(N, T, F, D) + len_keep_T = int(T * (1 - mask_t_prob)) + noise = torch.rand(N, T, device=x.device) # noise in [0, 1] + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_keep = ids_shuffle[:, :len_keep_T] + index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D) + # x_masked = torch.gather(x, dim=1, index=index) + # x_masked = x_masked.reshape(N,len_keep_T*F,D) + x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D + + # mask F + # x = x.reshape(N, T, F, D) + x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D + len_keep_F = int(F * (1 - mask_f_prob)) + noise = torch.rand(N, F, device=x.device) # noise in [0, 1] + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_keep = ids_shuffle[:, :len_keep_F] + # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D) + index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D) + x_masked = torch.gather(x, dim=1, index=index) + x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D + # x_masked = x_masked.reshape(N,len_keep*T,D) + x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D) + + return x_masked, None, None + + def forward_features_mask(self, x, mask_t_prob, mask_f_prob): + B = x.shape[0] # 4,1,1024,128 + x = self.patch_embed(x) # 4, 512, 768 + + x = x + self.pos_embed[:, 1:, :] + if self.random_masking_2d: + x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob) + else: + x, mask, ids_restore = self.random_masking(x, mask_t_prob) + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = self.pos_drop(x) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + + if self.global_pool: + x = x[:, 1:, :].mean(dim=1) # global pool without cls token + outcome = self.fc_norm(x) + else: + x = self.norm(x) + outcome = x[:, 0] + + return outcome + + # overwrite original timm + def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0): + if mask_t_prob > 0.0 or mask_f_prob > 0.0: + x = 
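# Standalone sketch of the separable masking in random_masking_2d above: drop whole
# time frames, then whole frequency bins, with two gathers (AS-style 64x8 grid):
import torch

N, T, F_, D = 2, 64, 8, 768
x = torch.randn(N, T * F_, D)
len_keep_T = int(T * (1 - 0.5))     # mask_t_prob = 0.5 -> keep 32 frames
len_keep_F = int(F_ * (1 - 0.25))   # mask_f_prob = 0.25 -> keep 6 bins

x = x.reshape(N, T, F_, D)
ids_T = torch.argsort(torch.rand(N, T), dim=1)[:, :len_keep_T]
x = torch.gather(x, 1, ids_T.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F_, D))

x = x.permute(0, 2, 1, 3)           # N, F, T', D
ids_F = torch.argsort(torch.rand(N, F_), dim=1)[:, :len_keep_F]
x = torch.gather(x, 1, ids_F.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D))
x = x.permute(0, 2, 1, 3).reshape(N, len_keep_T * len_keep_F, D)
print(x.shape)                      # torch.Size([2, 192, 768])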
self.forward_features_mask( + x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob + ) + else: + x = self.forward_features(x) + x = self.head(x) + return x + + +def vit_small_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_base_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_large_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_huge_patch14(**kwargs): + model = VisionTransformer( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/crop.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/crop.py new file mode 100644 index 0000000000000000000000000000000000000000..525e3c783c3d348e593dc89c2b5fb8520918e9ea --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/crop.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch + +from torchvision import transforms +from torchvision.transforms import functional as F + + +class RandomResizedCrop(transforms.RandomResizedCrop): + """ + RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. + This may lead to results different with torchvision's version. + Following BYOL's TF code: + https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 + """ + + @staticmethod + def get_params(img, scale, ratio): + width, height = F._get_image_size(img) + area = height * width + + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + log_ratio = torch.log(torch.tensor(ratio)) + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) + ).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + w = min(w, width) + h = min(h, height) + + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + + return i, j, h, w diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/datasets.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..b90f89a7d5f78c31bc9113dd88b632b0c234f10a --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/datasets.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +import os +import PIL + +from torchvision import datasets, transforms + +from timm.data import create_transform +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def build_dataset(is_train, args): + transform = build_transform(is_train, args) + + root = os.path.join(args.data_path, "train" if is_train else "val") + dataset = datasets.ImageFolder(root, transform=transform) + + print(dataset) + + return dataset + + +def build_transform(is_train, args): + mean = IMAGENET_DEFAULT_MEAN + std = IMAGENET_DEFAULT_STD + # train transform + if is_train: + # this should always dispatch to transforms_imagenet_train + transform = create_transform( + input_size=args.input_size, + is_training=True, + color_jitter=args.color_jitter, + auto_augment=args.aa, + interpolation="bicubic", + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + mean=mean, + std=std, + ) + return transform + + # eval transform + t = [] + if args.input_size <= 224: + crop_pct = 224 / 256 + else: + crop_pct = 1.0 + size = int(args.input_size / crop_pct) + t.append( + transforms.Resize( + size, interpolation=PIL.Image.BICUBIC + ), # to maintain same ratio w.r.t. 224 images + ) + t.append(transforms.CenterCrop(args.input_size)) + + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(mean, std)) + return transforms.Compose(t) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lars.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lars.py new file mode 100644 index 0000000000000000000000000000000000000000..fc43923d22cf2c9af4ae9166612c3f3477faf254 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lars.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# LARS optimizer, implementation from MoCo v3: +# https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- + +import torch + + +class LARS(torch.optim.Optimizer): + """ + LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
+ """ + + def __init__( + self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001 + ): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + trust_coefficient=trust_coefficient, + ) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self): + for g in self.param_groups: + for p in g["params"]: + dp = p.grad + + if dp is None: + continue + + if p.ndim > 1: # if not normalization gamma/beta or bias + dp = dp.add(p, alpha=g["weight_decay"]) + param_norm = torch.norm(p) + update_norm = torch.norm(dp) + one = torch.ones_like(param_norm) + q = torch.where( + param_norm > 0.0, + torch.where( + update_norm > 0, + (g["trust_coefficient"] * param_norm / update_norm), + one, + ), + one, + ) + dp = dp.mul(q) + + param_state = self.state[p] + if "mu" not in param_state: + param_state["mu"] = torch.zeros_like(p) + mu = param_state["mu"] + mu.mul_(g["momentum"]).add_(dp) + p.add_(mu, alpha=-g["lr"]) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_decay.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_decay.py new file mode 100644 index 0000000000000000000000000000000000000000..e90ed69d7b8d019dbf5d90571541668e2bd8efe8 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_decay.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# ELECTRA https://github.com/google-research/electra +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + + +def param_groups_lrd( + model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75 +): + """ + Parameter groups for layer-wise lr decay + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + param_group_names = {} + param_groups = {} + + num_layers = len(model.blocks) + 1 + + layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if p.ndim == 1 or n in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0.0 + else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = get_layer_id_for_vit(n, num_layers) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_group_names: + this_scale = layer_scales[layer_id] + + param_group_names[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["params"].append(n) + param_groups[group_name]["params"].append(p) + + # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + + return list(param_groups.values()) + + +def get_layer_id_for_vit(name, num_layers): + """ + Assign a parameter with its layer id + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + """ + if name in ["cls_token", "pos_embed"]: + return 0 + elif name.startswith("patch_embed"): + return 0 + elif name.startswith("blocks"): + return int(name.split(".")[1]) + 1 + else: + return num_layers diff --git 
a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_sched.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_sched.py new file mode 100644 index 0000000000000000000000000000000000000000..efe184d8e3fb63ec6b4f83375b6ea719985900de --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_sched.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( + 1.0 + + math.cos( + math.pi + * (epoch - args.warmup_epochs) + / (args.epochs - args.warmup_epochs) + ) + ) + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/misc.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..74184e09e23e0e174350b894b0cff29600c18b71 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/misc.py @@ -0,0 +1,453 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import torch +import torch.distributed as dist +from torch._six import inf + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
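# Sketch of the warmup + half-cycle cosine schedule implemented by adjust_learning_rate
# above; the args namespace is hypothetical, and passing a fractional epoch
# (epoch + step / len(loader)) is how it is usually called once per iteration.
import argparse
import torch
from FlashSR.AudioSR.latent_diffusion.modules.audiomae.util.lr_sched import adjust_learning_rate

args = argparse.Namespace(lr=1e-3, min_lr=1e-6, warmup_epochs=5, epochs=100)
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
for epoch in (0.0, 2.5, 5.0, 50.0, 100.0):
    print(epoch, adjust_learning_rate(optimizer, epoch, args))
# 0.0 -> 0.0, 2.5 -> 5e-4 (linear warmup), 5.0 -> 1e-3, then cosine decay toward min_lr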
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print("[{}] 
".format(now), end="") # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.dist_url = "tcp://%s:%s" % ( + os.environ["MASTER_ADDR"], + os.environ["MASTER_PORT"], + ) + os.environ["LOCAL_RANK"] = str(args.gpu) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}): {}, gpu {}".format( + args.rank, args.dist_url, args.gpu + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__( + self, + loss, + optimizer, + clip_grad=None, + parameters=None, + create_graph=False, + update_grad=True, + ): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_( + optimizer + ) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.0) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm( + torch.stack( + [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] + ), + norm_type, + ) + return total_norm + + +def 
save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + if loss_scaler is not None: + checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)] + for checkpoint_path in checkpoint_paths: + to_save = { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, + "scaler": loss_scaler.state_dict(), + "args": args, + } + + save_on_master(to_save, checkpoint_path) + else: + client_state = {"epoch": epoch} + model.save_checkpoint( + save_dir=args.output_dir, + tag="checkpoint-%s" % epoch_name, + client_state=client_state, + ) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + if args.resume: + if args.resume.startswith("https"): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location="cpu", check_hash=True + ) + else: + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + print("Resume checkpoint %s" % args.resume) + if ( + "optimizer" in checkpoint + and "epoch" in checkpoint + and not (hasattr(args, "eval") and args.eval) + ): + optimizer.load_state_dict(checkpoint["optimizer"]) + args.start_epoch = checkpoint["epoch"] + 1 + if "scaler" in checkpoint: + loss_scaler.load_state_dict(checkpoint["scaler"]) + print("With optim & sched!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
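# Sketch of the checkpoint round trip provided by save_model / load_model above, with a
# hypothetical args namespace; both calls assume model / optimizer / loss_scaler exist.
import argparse

args = argparse.Namespace(output_dir="./output", resume="./output/checkpoint-10.pth",
                          start_epoch=0, eval=False)
# after every epoch:
#     save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler)
# on restart, load_model restores weights, optimizer state, epoch counter, and scaler:
#     load_model(args, model_without_ddp, optimizer, loss_scaler)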
+ """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +def merge_vmae_to_avmae(avmae_state_dict, vmae_ckpt): + # keys_to_copy=['pos_embed','patch_embed'] + # replaced=0 + + vmae_ckpt["cls_token"] = vmae_ckpt["cls_token_v"] + vmae_ckpt["mask_token"] = vmae_ckpt["mask_token_v"] + + # pos_emb % not trainable, use default + pos_embed_v = vmae_ckpt["pos_embed_v"] # 1,589,768 + pos_embed = pos_embed_v[:, 1:, :] # 1,588,768 + cls_embed = pos_embed_v[:, 0, :].unsqueeze(1) + pos_embed = pos_embed.reshape(1, 2, 14, 14, 768).sum(dim=1) # 1, 14, 14, 768 + print("Position interpolate from 14,14 to 64,8") + pos_embed = pos_embed.permute(0, 3, 1, 2) # 1, 14,14,768 -> 1,768,14,14 + pos_embed = torch.nn.functional.interpolate( + pos_embed, size=(64, 8), mode="bicubic", align_corners=False + ) + pos_embed = pos_embed.permute(0, 2, 3, 1).flatten( + 1, 2 + ) # 1, 14, 14, 768 => 1, 196,768 + pos_embed = torch.cat((cls_embed, pos_embed), dim=1) + assert vmae_ckpt["pos_embed"].shape == pos_embed.shape + vmae_ckpt["pos_embed"] = pos_embed + # patch_emb + # aggregate 3 channels in video-rgb ckpt to 1 channel for audio + v_weight = vmae_ckpt["patch_embed_v.proj.weight"] # 768,3,2,16,16 + new_proj_weight = torch.nn.Parameter(v_weight.sum(dim=2).sum(dim=1).unsqueeze(1)) + assert new_proj_weight.shape == vmae_ckpt["patch_embed.proj.weight"].shape + vmae_ckpt["patch_embed.proj.weight"] = new_proj_weight + vmae_ckpt["patch_embed.proj.bias"] = vmae_ckpt["patch_embed_v.proj.bias"] + + # hack + vmae_ckpt["norm.weight"] = vmae_ckpt["norm_v.weight"] + vmae_ckpt["norm.bias"] = vmae_ckpt["norm_v.bias"] + + # replace transformer encoder + for k, v in vmae_ckpt.items(): + if k.startswith("blocks."): + kk = k.replace("blocks.", "blocks_v.") + vmae_ckpt[k] = vmae_ckpt[kk] + elif k.startswith("blocks_v."): + pass + else: + print(k) + print(k) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/patch_embed.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..ac1e4d436c6f79aef9bf1de32cdac5d4f037c775 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/patch_embed.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn +from timm.models.layers import to_2tuple + + +class PatchEmbed_org(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ x = self.proj(x) + y = x.flatten(2).transpose(1, 2) + return y + + +class PatchEmbed_new(nn.Module): + """Flexible Image to Patch Embedding""" + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + + self.img_size = img_size + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride + ) # with overlapped patches + # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) + # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w + self.patch_hw = (h, w) + self.num_patches = h * w + + def get_output_shape(self, img_size): + # todo: don't be lazy.. + return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + # x = self.proj(x).flatten(2).transpose(1, 2) + x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12 + x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212 + x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768 + return x + + +class PatchEmbed3D_new(nn.Module): + """Flexible Image to Patch Embedding""" + + def __init__( + self, + video_size=(16, 224, 224), + patch_size=(2, 16, 16), + in_chans=3, + embed_dim=768, + stride=(2, 16, 16), + ): + super().__init__() + + self.video_size = video_size + self.patch_size = patch_size + self.in_chans = in_chans + + self.proj = nn.Conv3d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride + ) + _, _, t, h, w = self.get_output_shape(video_size) # n, emb_dim, h, w + self.patch_thw = (t, h, w) + self.num_patches = t * h * w + + def get_output_shape(self, video_size): + # todo: don't be lazy.. + return self.proj( + torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2]) + ).shape + + def forward(self, x): + B, C, T, H, W = x.shape + x = self.proj(x) # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14 + x = x.flatten(2) # 32, 768, 1568 + x = x.transpose(1, 2) # 32, 768, 1568 -> 32, 1568, 768 + return x + + +if __name__ == "__main__": + # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16)) + # input = torch.rand(8,1,1024,128) + # output = patch_emb(input) + # print(output.shape) # (8,512,64) + + patch_emb = PatchEmbed3D_new( + video_size=(6, 224, 224), + patch_size=(2, 16, 16), + in_chans=3, + embed_dim=768, + stride=(2, 16, 16), + ) + input = torch.rand(8, 3, 6, 224, 224) + output = patch_emb(input) + print(output.shape) # (8,64) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/pos_embed.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9177ed98dffcf35264f38aff94e7f00fb50abf --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/pos_embed.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
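# Sketch of the overlapped patch embedding above: 16x16 patches with stride 10 turn a
# 1024x128 mel spectrogram into a 101x12 token grid (1212 tokens), matching the shape
# comments in PatchEmbed_new.forward.
import torch
from FlashSR.AudioSR.latent_diffusion.modules.audiomae.util.patch_embed import PatchEmbed_new

embed = PatchEmbed_new(img_size=(1024, 128), patch_size=16, in_chans=1,
                       embed_dim=768, stride=10)
tokens = embed(torch.randn(2, 1, 1024, 128))
print(embed.patch_hw, tokens.shape)     # (101, 12) torch.Size([2, 1212, 768])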
+# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np + +import torch + + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size[0], dtype=np.float32) + grid_w = np.arange(grid_size[1], dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + # omega = np.arange(embed_dim // 2, dtype=np.float) + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 
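# Sketch of the fixed sin-cos table built by get_2d_sincos_pos_embed_flexible above,
# for the non-square 64x8 time-frequency patch grid used elsewhere in this repo:
pos_embed_table = get_2d_sincos_pos_embed_flexible(768, (64, 8), cls_token=True)
print(pos_embed_table.shape)    # (513, 768): 64*8 patch positions plus an all-zero cls row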
0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + # new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size[0], orig_size[1], new_size[0], new_size[1]) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size[0], orig_size[1], embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size[0], new_size[1]), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + model.pos_embed.shape[-2] - num_patches + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size[0], orig_size[1], new_size[0], new_size[1]) + ) + # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1) + pos_tokens = pos_embed_checkpoint[:, 1:, :] # remove + pos_tokens = pos_tokens.reshape( + -1, orig_size[0], orig_size[1], embedding_size + ) # .permute(0, 3, 1, 2) + # pos_tokens = torch.nn.functional.interpolate( + # pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False) + + # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + pos_tokens = pos_tokens[:, :, : new_size[1], :] # assume only time diff + pos_tokens = pos_tokens.flatten(1, 2) + new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def interpolate_patch_embed_audio( + model, + checkpoint_model, + orig_channel, + new_channel=1, + kernel_size=(16, 16), + stride=(16, 
16), + padding=(0, 0), +): + if orig_channel != new_channel: + if "patch_embed.proj.weight" in checkpoint_model: + # aggregate 3 channels in rgb ckpt to 1 channel for audio + new_proj_weight = torch.nn.Parameter( + torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze( + 1 + ) + ) + checkpoint_model["patch_embed.proj.weight"] = new_proj_weight diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/stat.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/stat.py new file mode 100644 index 0000000000000000000000000000000000000000..3f8137249503f6eaa25c3170fe5ef6b87f187347 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/stat.py @@ -0,0 +1,76 @@ +import numpy as np +from scipy import stats +from sklearn import metrics +import torch + + +def d_prime(auc): + standard_normal = stats.norm() + d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) + return d_prime + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +def calculate_stats(output, target): + """Calculate statistics including mAP, AUC, etc. + + Args: + output: 2d array, (samples_num, classes_num) + target: 2d array, (samples_num, classes_num) + + Returns: + stats: list of statistic of each class. + """ + + classes_num = target.shape[-1] + stats = [] + + # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet + acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1)) + + # Class-wise statistics + for k in range(classes_num): + # Average precision + avg_precision = metrics.average_precision_score( + target[:, k], output[:, k], average=None + ) + + # AUC + # auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None) + + # Precisions, recalls + (precisions, recalls, thresholds) = metrics.precision_recall_curve( + target[:, k], output[:, k] + ) + + # FPR, TPR + (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k]) + + save_every_steps = 1000 # Sample statistics to reduce size + dict = { + "precisions": precisions[0::save_every_steps], + "recalls": recalls[0::save_every_steps], + "AP": avg_precision, + "fpr": fpr[0::save_every_steps], + "fnr": 1.0 - tpr[0::save_every_steps], + # 'auc': auc, + # note acc is not class-wise, this is just to keep consistent with other metrics + "acc": acc, + } + stats.append(dict) + + return stats diff --git a/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/model.py b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7d49e3e2cda157afcc5b925734f459e4a7d2a6c8 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/model.py @@ -0,0 +1,1069 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from 
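# Sketch of calculate_stats above: it expects per-class scores and multi-hot targets of
# shape (num_samples, num_classes); mAP is the mean of the per-class "AP" entries.
# Random dummy data only.
import numpy as np

scores = np.random.rand(200, 10)
targets = (np.random.rand(200, 10) > 0.7).astype(float)
stats = calculate_stats(scores, targets)
print("mAP:", np.mean([s["AP"] for s in stats]), "acc:", stats[0]["acc"])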
audiosr.latent_diffusion.util import instantiate_from_config +from audiosr.latent_diffusion.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class UpsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=1, padding=2 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class DownsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + 
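# Sketch of get_timestep_embedding above: integer diffusion steps become sinusoidal
# vectors (first half sin, second half cos), analogous to transformer position encodings.
import torch

t_emb = get_timestep_embedding(torch.tensor([0, 10, 999]), 512)
print(t_emb.shape)      # torch.Size([3, 512])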
in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + 
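# Sketch of the vanilla attention block built by make_attn above: single-head
# self-attention over all h*w positions with a residual connection, so the feature-map
# shape is preserved.
import torch

attn = make_attn(64, attn_type="vanilla")    # returns AttnBlock(64)
out = attn(torch.randn(2, 64, 32, 16))
print(out.shape)                             # torch.Size([2, 64, 32, 16])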
# timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + 
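+                # hs.pop() retrieves the matching encoder feature map, so each
+                # up-block consumes its U-Net skip connection by channel
+                # concatenation before the optional attention below.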
if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + downsample_time_stride4_levels=[], + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level in self.downsample_time_stride4_levels: + down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv) + else: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 
8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + downsample_time_stride4_levels=[], + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # compute in_ch_mult, block_in and curr_res at lowest res + (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + # print( + # "Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape) + # ) + # ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level - 1 in self.downsample_time_stride4_levels: + up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) + else: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, 
in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x).contiguous() + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + 
return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x + + +class FirstStagePostProcessor(nn.Module): + def __init__( + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, + ): + 
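+        # Encodes x with a frozen pretrained first-stage model, projects the
+        # latent to n_channels, then applies one ResnetBlock + Downsample pair
+        # per entry of ch_mult (optionally flattened to "b (h w) c" at the end).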
super().__init__() + if pretrained_config is None: + assert ( + pretrained_model is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert ( + pretrained_config is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d( + in_channels, n_channels, kernel_size=3, stride=1, padding=1 + ) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append( + ResnetBlock( + in_channels=ch_in, out_channels=m * n_channels, dropout=dropout + ) + ) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, "b c h w -> b (h w) c") + return z diff --git a/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/openaimodel.py b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..3712857caf29672de9a1dea007ac98f4aaed1a0b --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1103 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from audiosr.latent_diffusion.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from audiosr.latent_diffusion.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1).contiguous() # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + 
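+# Shape sketch for AttentionPool2d (illustrative values, not used elsewhere in
+# this file): the HW spatial positions plus a prepended mean token attend over
+# each other, and the projected mean-token output is returned.
+#   pool = AttentionPool2d(spacial_dim=16, embed_dim=256, num_heads_channels=64)
+#   feat = th.randn(2, 256, 16, 16)  # [N, C, H, W], C == embed_dim, H*W == spacial_dim**2
+#   vec = pool(feat)                 # -> [2, 256], or [2, output_dim] if output_dim is set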
+class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context_list=None, mask_list=None): + # The first spatial transformer block does not have context + spatial_transformer_id = 0 + context_list = [None] + context_list + mask_list = [None] + mask_list + + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + if spatial_transformer_id >= len(context_list): + context, mask = None, None + else: + context, mask = ( + context_list[spatial_transformer_id], + mask_list[spatial_transformer_id], + ) + + x = layer(x, context, mask=mask) + spatial_transformer_id += 1 + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. 
+ :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1).contiguous() + qkv = self.qkv(self.norm(x)).contiguous() + h = self.attention(qkv).contiguous() + h = self.proj_out(h).contiguous() + return (x + h).reshape(b, c, *spatial).contiguous() + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = ( + qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1) + ) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum( + "bts,bcs->bct", + weight, + v.reshape(bs * self.n_heads, ch, length).contiguous(), + ) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + extra_sa_layer=True, + num_classes=None, + extra_film_condition_dim=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=True, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.extra_film_condition_dim = extra_film_condition_dim + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + # assert not ( + # self.num_classes is not None and self.extra_film_condition_dim is not None + # ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim." + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.use_extra_film_by_concat = self.extra_film_condition_dim is not None + + if self.extra_film_condition_dim is not None: + self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim) + print( + "+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " + % self.extra_film_condition_dim + ) + + if context_dim is not None and not use_spatial_transformer: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 
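+        # context_dim is normalized to a list below; one SpatialTransformer is
+        # built per entry, so single- and multi-source cross-attention
+        # conditioning share the same construction code path.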
+ + if context_dim is not None and not isinstance(context_dim, list): + context_dim = [context_dim] + elif context_dim is None: + context_dim = [None] # At least use one spatial transformer + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + middle_layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + if extra_sa_layer: + middle_layers.append( + SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=None + ) + ) + for context_dim_id in range(len(context_dim)): + middle_layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + middle_layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + 
) + self.middle_block = TimestepEmbedSequential(*middle_layers) + + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + self.shape_reported = False + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward( + self, + x, + timesteps=None, + y=None, + context_list=None, + context_attn_mask_list=None, + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional + :return: an [N x C x ...] Tensor of outputs. 
+ """ + if not self.shape_reported: + # print("The shape of UNet input is", x.size()) + self.shape_reported = True + + assert (y is not None) == ( + self.num_classes is not None or self.extra_film_condition_dim is not None + ), "must specify y if and only if the model is class-conditional or film embedding conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + # if self.num_classes is not None: + # assert y.shape == (x.shape[0],) + # emb = emb + self.label_emb(y) + + if self.use_extra_film_by_concat: + emb = th.cat([emb, self.film_emb(y)], dim=-1) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context_list, context_attn_mask_list) + hs.append(h) + h = self.middle_block(h, emb, context_list, context_attn_mask_list) + for module in self.output_blocks: + concate_tensor = hs.pop() + h = th.cat([h, concate_tensor], dim=1) + h = module(h, emb, context_list, context_attn_mask_list) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, 
+ dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/util.py b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b1fdc33082e4eed6bba217d15ab09cbe952c0116 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/util.py @@ -0,0 +1,294 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
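+# Usage sketch for the helpers defined below (values are illustrative only):
+#   betas = make_beta_schedule("linear", n_timestep=1000)  # np.ndarray, shape (1000,)
+#   t = torch.tensor([0, 250, 999])
+#   emb = timestep_embedding(t, dim=128)                   # torch.Tensor, shape (3, 128)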
+ + +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from FlashSR.AudioSR.latent_diffusion.util import instantiate_from_config + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + # betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t).contiguous() + return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. 
+ """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/FlashSR/AudioSR/latent_diffusion/modules/distributions/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/FlashSR/AudioSR/latent_diffusion/modules/distributions/distributions.py b/FlashSR/AudioSR/latent_diffusion/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..58eb535e7769f402169ddff77ee45c96ba3650d9 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/distributions/distributions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + 
self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/ema.py b/FlashSR/AudioSR/latent_diffusion/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..880ca3d205d9b4d7450e146930a93f2e63c58b70 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/ema.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. 
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/FlashSR/AudioSR/latent_diffusion/modules/encoders/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/FlashSR/AudioSR/latent_diffusion/modules/encoders/modules.py b/FlashSR/AudioSR/latent_diffusion/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebec8699ae3591d067a6687c754a9b292c701d05
--- /dev/null
+++ b/FlashSR/AudioSR/latent_diffusion/modules/encoders/modules.py
@@ -0,0 +1,682 @@
+import torch
+import logging
+import torch.nn as nn
+#from audiosr.clap.open_clip import create_model
+#from audiosr.clap.training.data import get_audio_features
+import torchaudio
+#from transformers import RobertaTokenizer, AutoTokenizer, T5EncoderModel
+import torch.nn.functional as F
+from audiosr.latent_diffusion.modules.audiomae.AudioMAE import Vanilla_AudioMAE
+from audiosr.latent_diffusion.modules.phoneme_encoder.encoder import TextEncoder
+from audiosr.latent_diffusion.util import instantiate_from_config
+
+from transformers import AutoTokenizer, T5Config
+
+
+import numpy as np
+
+"""
+The model forward function can return three types of data:
+1. tensor: used directly as the conditioning signal
+2. dict: with a main key used as the condition; other keys can be used to pass loss terms, intermediate results, etc.
+3. list: of length 2, in which the first element is the conditioning tensor and the second element is the attention mask.
+ +The output shape for the cross attention condition should be: +x,x_mask = [bs, seq_len, emb_dim], [bs, seq_len] + +All the returned data, in which will be used as diffusion input, will need to be in float type +""" + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class PhonemeEncoder(nn.Module): + def __init__(self, vocabs_size=41, pad_length=250, pad_token_id=None): + super().__init__() + """ + encoder = PhonemeEncoder(40) + data = torch.randint(0, 39, (2, 250)) + output = encoder(data) + import ipdb;ipdb.set_trace() + """ + assert pad_token_id is not None + + self.device = None + self.PAD_LENGTH = int(pad_length) + self.pad_token_id = pad_token_id + self.pad_token_sequence = torch.tensor([self.pad_token_id] * self.PAD_LENGTH) + + self.text_encoder = TextEncoder( + n_vocab=vocabs_size, + out_channels=192, + hidden_channels=192, + filter_channels=768, + n_heads=2, + n_layers=6, + kernel_size=3, + p_dropout=0.1, + ) + + self.learnable_positional_embedding = torch.nn.Parameter( + torch.zeros((1, 192, self.PAD_LENGTH)) + ) # [batchsize, seqlen, padlen] + self.learnable_positional_embedding.requires_grad = True + + # Required + def get_unconditional_condition(self, batchsize): + unconditional_tokens = self.pad_token_sequence.expand( + batchsize, self.PAD_LENGTH + ) + return self(unconditional_tokens) # Need to return float type + + # def get_unconditional_condition(self, batchsize): + + # hidden_state = torch.zeros((batchsize, self.PAD_LENGTH, 192)).to(self.device) + # attention_mask = torch.ones((batchsize, self.PAD_LENGTH)).to(self.device) + # return [hidden_state, attention_mask] # Need to return float type + + def _get_src_mask(self, phoneme): + src_mask = phoneme != self.pad_token_id + return src_mask + + def _get_src_length(self, phoneme): + src_mask = self._get_src_mask(phoneme) + length = torch.sum(src_mask, dim=-1) + return length + + # def make_empty_condition_unconditional(self, src_length, text_emb, attention_mask): + # # src_length: [bs] + # # text_emb: [bs, 192, pad_length] + # # attention_mask: [bs, pad_length] + # mask = src_length[..., None, None] > 1 + # text_emb = text_emb * mask + + # attention_mask[src_length < 1] = attention_mask[src_length < 1] * 0.0 + 1.0 + # return text_emb, attention_mask + + def forward(self, phoneme_idx): + if self.device is None: + self.device = self.learnable_positional_embedding.device + self.pad_token_sequence = self.pad_token_sequence.to(self.device) + + phoneme_idx = phoneme_idx.to(self.device) + + src_length = self._get_src_length(phoneme_idx) + text_emb, m, logs, text_emb_mask = self.text_encoder(phoneme_idx, src_length) + text_emb = text_emb + self.learnable_positional_embedding + + # text_emb, text_emb_mask = self.make_empty_condition_unconditional(src_length, text_emb, text_emb_mask) + + return [ + text_emb.permute(0, 2, 1), + text_emb_mask.squeeze(1), + ] # [2, 250, 192], [2, 250] + + +class VAEFeatureExtract(nn.Module): + def __init__(self, first_stage_config): + super().__init__() + # self.tokenizer = AutoTokenizer.from_pretrained("gpt2") + self.vae = None + self.instantiate_first_stage(first_stage_config) + self.device = None + self.unconditional_cond = None + + def get_unconditional_condition(self, batchsize): + return self.unconditional_cond.unsqueeze(0).expand(batchsize, -1, -1, -1) + + def instantiate_first_stage(self, config): + self.vae = instantiate_from_config(config) + self.vae.eval() + for p in 
self.vae.parameters(): + p.requires_grad = False + self.vae.train = disabled_train + + def forward(self, batch): + assert self.vae.training == False + if self.device is None: + self.device = next(self.vae.parameters()).device + + with torch.no_grad(): + vae_embed = self.vae.encode(batch.unsqueeze(1)).sample() + + self.unconditional_cond = -11.4981 + vae_embed[0].clone() * 0.0 + + return vae_embed.detach() + + +class FlanT5HiddenState(nn.Module): + """ + llama = FlanT5HiddenState() + data = ["","this is not an empty sentence"] + encoder_hidden_states = llama(data) + import ipdb;ipdb.set_trace() + """ + + def __init__( + self, text_encoder_name="google/flan-t5-large", freeze_text_encoder=True + ): + super().__init__() + self.freeze_text_encoder = freeze_text_encoder + self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name) + #self.model = T5EncoderModel(T5Config.from_pretrained(text_encoder_name)) + if freeze_text_encoder: + self.model.eval() + for p in self.model.parameters(): + p.requires_grad = False + else: + print("=> The text encoder is learnable") + + self.empty_hidden_state_cfg = None + self.device = None + + # Required + def get_unconditional_condition(self, batchsize): + param = next(self.model.parameters()) + if self.freeze_text_encoder: + assert param.requires_grad == False + + # device = param.device + if self.empty_hidden_state_cfg is None: + self.empty_hidden_state_cfg, _ = self([""]) + + hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float() + attention_mask = ( + torch.ones((batchsize, hidden_state.size(1))) + .to(hidden_state.device) + .float() + ) + return [hidden_state, attention_mask] # Need to return float type + + def forward(self, batch): + param = next(self.model.parameters()) + if self.freeze_text_encoder: + assert param.requires_grad == False + + if self.device is None: + self.device = param.device + + # print("Manually change text") + # for i in range(len(batch)): + # batch[i] = "dog barking" + try: + return self.encode_text(batch) + except Exception as e: + print(e, batch) + logging.exception("An error occurred: %s", str(e)) + + def encode_text(self, prompt): + device = self.model.device + batch = self.tokenizer( + prompt, + max_length=128, # self.tokenizer.model_max_length + padding=True, + truncation=True, + return_tensors="pt", + ) + input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to( + device + ) + # Get text encoding + if self.freeze_text_encoder: + with torch.no_grad(): + encoder_hidden_states = self.model( + input_ids=input_ids, attention_mask=attention_mask + )[0] + else: + encoder_hidden_states = self.model( + input_ids=input_ids, attention_mask=attention_mask + )[0] + return [ + encoder_hidden_states.detach(), + attention_mask.float(), + ] + + +class AudioMAEConditionCTPoolRandTFSeparated(nn.Module): + """ + audiomae = AudioMAEConditionCTPool2x2() + data = torch.randn((4, 1024, 128)) + output = audiomae(data) + import ipdb;ipdb.set_trace() + exit(0) + """ + + def __init__( + self, + time_pooling_factors=[1, 2, 4, 8], + freq_pooling_factors=[1, 2, 4, 8], + eval_time_pooling=None, + eval_freq_pooling=None, + mask_ratio=0.0, + regularization=False, + no_audiomae_mask=True, + no_audiomae_average=False, + ): + super().__init__() + self.device = None + self.time_pooling_factors = time_pooling_factors + self.freq_pooling_factors = freq_pooling_factors + self.no_audiomae_mask = no_audiomae_mask + self.no_audiomae_average = no_audiomae_average + + self.eval_freq_pooling = eval_freq_pooling + 
self.eval_time_pooling = eval_time_pooling + self.mask_ratio = mask_ratio + self.use_reg = regularization + + self.audiomae = Vanilla_AudioMAE() + self.audiomae.eval() + for p in self.audiomae.parameters(): + p.requires_grad = False + + # Required + def get_unconditional_condition(self, batchsize): + param = next(self.audiomae.parameters()) + assert param.requires_grad == False + device = param.device + # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors) + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))] + # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] + token_num = int(512 / (time_pool * freq_pool)) + return [ + torch.zeros((batchsize, token_num, 768)).to(device).float(), + torch.ones((batchsize, token_num)).to(device).float(), + ] + + def pool(self, representation, time_pool=None, freq_pool=None): + assert representation.size(-1) == 768 + representation = representation[:, 1:, :].transpose(1, 2) + bs, embedding_dim, token_num = representation.size() + representation = representation.reshape(bs, embedding_dim, 64, 8) + + if self.training: + if time_pool is None and freq_pool is None: + time_pool = min( + 64, + self.time_pooling_factors[ + np.random.choice(list(range(len(self.time_pooling_factors)))) + ], + ) + freq_pool = min( + 8, + self.freq_pooling_factors[ + np.random.choice(list(range(len(self.freq_pooling_factors)))) + ], + ) + # freq_pool = min(8, time_pool) # TODO here I make some modification. + else: + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + + self.avgpooling = nn.AvgPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + self.maxpooling = nn.MaxPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + + pooled = ( + self.avgpooling(representation) + self.maxpooling(representation) + ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] + pooled = pooled.flatten(2).transpose(1, 2) + return pooled # [bs, token_num, embedding_dim] + + def regularization(self, x): + assert x.size(-1) == 768 + x = F.normalize(x, p=2, dim=-1) + return x + + # Required + def forward(self, batch, time_pool=None, freq_pool=None): + assert batch.size(-2) == 1024 and batch.size(-1) == 128 + + if self.device is None: + self.device = batch.device + + batch = batch.unsqueeze(1) + with torch.no_grad(): + representation = self.audiomae( + batch, + mask_ratio=self.mask_ratio, + no_mask=self.no_audiomae_mask, + no_average=self.no_audiomae_average, + ) + representation = self.pool(representation, time_pool, freq_pool) + if self.use_reg: + representation = self.regularization(representation) + return [ + representation, + torch.ones((representation.size(0), representation.size(1))) + .to(representation.device) + .float(), + ] + + +class AudioMAEConditionCTPoolRand(nn.Module): + """ + audiomae = AudioMAEConditionCTPool2x2() + data = torch.randn((4, 1024, 128)) + output = audiomae(data) + import ipdb;ipdb.set_trace() + exit(0) + """ + + def __init__( + self, + time_pooling_factors=[1, 2, 4, 8], + freq_pooling_factors=[1, 2, 4, 8], + eval_time_pooling=None, + eval_freq_pooling=None, + mask_ratio=0.0, + regularization=False, + no_audiomae_mask=True, + no_audiomae_average=False, + ): + super().__init__() + self.device = None + self.time_pooling_factors = 
time_pooling_factors + self.freq_pooling_factors = freq_pooling_factors + self.no_audiomae_mask = no_audiomae_mask + self.no_audiomae_average = no_audiomae_average + + self.eval_freq_pooling = eval_freq_pooling + self.eval_time_pooling = eval_time_pooling + self.mask_ratio = mask_ratio + self.use_reg = regularization + + self.audiomae = Vanilla_AudioMAE() + self.audiomae.eval() + for p in self.audiomae.parameters(): + p.requires_grad = False + + # Required + def get_unconditional_condition(self, batchsize): + param = next(self.audiomae.parameters()) + assert param.requires_grad == False + device = param.device + # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors) + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))] + # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] + token_num = int(512 / (time_pool * freq_pool)) + return [ + torch.zeros((batchsize, token_num, 768)).to(device).float(), + torch.ones((batchsize, token_num)).to(device).float(), + ] + + def pool(self, representation, time_pool=None, freq_pool=None): + assert representation.size(-1) == 768 + representation = representation[:, 1:, :].transpose(1, 2) + bs, embedding_dim, token_num = representation.size() + representation = representation.reshape(bs, embedding_dim, 64, 8) + + if self.training: + if time_pool is None and freq_pool is None: + time_pool = min( + 64, + self.time_pooling_factors[ + np.random.choice(list(range(len(self.time_pooling_factors)))) + ], + ) + # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] + freq_pool = min(8, time_pool) # TODO here I make some modification. 
+ else: + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + + self.avgpooling = nn.AvgPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + self.maxpooling = nn.MaxPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + + pooled = ( + self.avgpooling(representation) + self.maxpooling(representation) + ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] + pooled = pooled.flatten(2).transpose(1, 2) + return pooled # [bs, token_num, embedding_dim] + + def regularization(self, x): + assert x.size(-1) == 768 + x = F.normalize(x, p=2, dim=-1) + return x + + # Required + def forward(self, batch, time_pool=None, freq_pool=None): + assert batch.size(-2) == 1024 and batch.size(-1) == 128 + + if self.device is None: + self.device = next(self.audiomae.parameters()).device + + batch = batch.unsqueeze(1).to(self.device) + with torch.no_grad(): + representation = self.audiomae( + batch, + mask_ratio=self.mask_ratio, + no_mask=self.no_audiomae_mask, + no_average=self.no_audiomae_average, + ) + representation = self.pool(representation, time_pool, freq_pool) + if self.use_reg: + representation = self.regularization(representation) + return [ + representation, + torch.ones((representation.size(0), representation.size(1))) + .to(representation.device) + .float(), + ] + + +class CLAPAudioEmbeddingClassifierFreev2(nn.Module): + def __init__( + self, + pretrained_path="", + enable_cuda=False, + sampling_rate=16000, + embed_mode="audio", + amodel="HTSAT-base", + unconditional_prob=0.1, + random_mute=False, + max_random_mute_portion=0.5, + training_mode=True, + ): + super().__init__() + self.device = "cpu" # The model itself is on cpu + self.cuda = enable_cuda + self.precision = "fp32" + self.amodel = amodel # or 'PANN-14' + self.tmodel = "roberta" # the best text encoder in our training + self.enable_fusion = False # False if you do not want to use the fusion model + self.fusion_type = "aff_2d" + self.pretrained = pretrained_path + self.embed_mode = embed_mode + self.embed_mode_orig = embed_mode + self.sampling_rate = sampling_rate + self.unconditional_prob = unconditional_prob + self.random_mute = random_mute + #self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") + self.max_random_mute_portion = max_random_mute_portion + self.training_mode = training_mode + self.model, self.model_cfg = create_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device=self.device, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + self.model = self.model.to(self.device) + audio_cfg = self.model_cfg["audio_cfg"] + self.mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=audio_cfg["sample_rate"], + n_fft=audio_cfg["window_size"], + win_length=audio_cfg["window_size"], + hop_length=audio_cfg["hop_size"], + center=True, + pad_mode="reflect", + power=2.0, + norm=None, + onesided=True, + n_mels=64, + f_min=audio_cfg["fmin"], + f_max=audio_cfg["fmax"], + ) + for p in self.model.parameters(): + p.requires_grad = False + self.unconditional_token = None + self.model.eval() + + def get_unconditional_condition(self, batchsize): + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0) + + def batch_to_list(self, batch): + ret = [] + for i in range(batch.size(0)): + ret.append(batch[i]) + return ret + + def make_decision(self, 
probability): + if float(torch.rand(1)) < probability: + return True + else: + return False + + def random_uniform(self, start, end): + val = torch.rand(1).item() + return start + (end - start) * val + + def _random_mute(self, waveform): + # waveform: [bs, t-steps] + t_steps = waveform.size(-1) + for i in range(waveform.size(0)): + mute_size = int( + self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion)) + ) + mute_start = int(self.random_uniform(0, t_steps - mute_size)) + waveform[i, mute_start : mute_start + mute_size] = 0 + return waveform + + def cos_similarity(self, waveform, text): + # waveform: [bs, t_steps] + original_embed_mode = self.embed_mode + with torch.no_grad(): + self.embed_mode = "audio" + # MPS currently does not support ComplexFloat dtype and operator 'aten::_fft_r2c' + if self.cuda: + audio_emb = self(waveform.cuda()) + else: + audio_emb = self(waveform.to("cpu")) + self.embed_mode = "text" + text_emb = self(text) + similarity = F.cosine_similarity(audio_emb, text_emb, dim=2) + self.embed_mode = original_embed_mode + return similarity.squeeze() + + def build_unconditional_emb(self): + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + + def forward(self, batch): + # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0 + # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0 + if self.model.training == True and not self.training_mode: + print( + "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters." + ) + self.model, self.model_cfg = create_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device="cuda" if self.cuda else "cpu", + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + for p in self.model.parameters(): + p.requires_grad = False + self.model.eval() + + if self.unconditional_token is None: + self.build_unconditional_emb() + + # if(self.training_mode): + # assert self.model.training == True + # else: + # assert self.model.training == False + + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + if self.embed_mode == "audio": + if not self.training: + print("INFO: clap model calculate the audio embedding as condition") + with torch.no_grad(): + # assert ( + # self.sampling_rate == 16000 + # ), "We only support 16000 sampling rate" + + # if self.random_mute: + # batch = self._random_mute(batch) + # batch: [bs, 1, t-samples] + if self.sampling_rate != 48000: + batch = torchaudio.functional.resample( + batch, orig_freq=self.sampling_rate, new_freq=48000 + ) + audio_data = batch.squeeze(1).to("cpu") + self.mel_transform = self.mel_transform.to(audio_data.device) + mel = self.mel_transform(audio_data) + audio_dict = get_audio_features( + audio_data, + mel, + 480000, + data_truncating="fusion", + data_filling="repeatpad", + audio_cfg=self.model_cfg["audio_cfg"], + ) + # [bs, 512] + embed = self.model.get_audio_embedding(audio_dict) + elif self.embed_mode == "text": + with torch.no_grad(): + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + text_data = self.tokenizer(batch) + + if isinstance(batch, str) or ( + isinstance(batch, list) and len(batch) == 1 + ): + for key in text_data.keys(): + text_data[key] = text_data[key].unsqueeze(0) + + embed = self.model.get_text_embedding(text_data) + + embed = embed.unsqueeze(1) + for i in range(embed.size(0)): + if 
self.make_decision(self.unconditional_prob): + embed[i] = self.unconditional_token + # embed = torch.randn((batch.size(0), 1, 512)).type_as(batch) + return embed.detach() + + def tokenizer(self, text): + result = self.tokenize( + text, + padding="max_length", + truncation=True, + max_length=512, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/attentions.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..287c409dbfe535d95c65a2c69d013965fce97348 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/attentions.py @@ -0,0 +1,430 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +import audiosr.latent_diffusion.modules.phoneme_encoder.commons as commons + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) 
+ self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, 
t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings + ) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert ( + t_s == t_t + ), "Local attention is only available for self-attention." + block_mask = ( + torch.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length) + ) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = ( + output.transpose(2, 3).contiguous().view(b, d, t_t) + ) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad( + x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) + + # Reshape and slice out the padded elements. 
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad( + x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) + ) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/commons.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..9515724c12ab2f856b9a2ec14e38cc63df9b85d6 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/commons.py @@ -0,0 +1,161 @@ +import math +import torch +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( 
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, 
clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/encoder.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3085c72c98627eb66e7c944b18b218f5fcee5321 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/encoder.py @@ -0,0 +1,50 @@ +import math +import torch +from torch import nn + +import audiosr.latent_diffusion.modules.phoneme_encoder.commons as commons +import audiosr.latent_diffusion.modules.phoneme_encoder.attentions as attentions + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab, + out_channels=192, + hidden_channels=192, + filter_channels=768, + n_heads=2, + n_layers=6, + kernel_size=3, + p_dropout=0.1, + ): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = attentions.Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/LICENSE b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4ad4ed1d5e34d95c8380768ec16405d789cc6de4 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2017 Keith Ito + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cea4352df38cbedbf7f8e30cacd2a14eb0eb17f2 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/__init__.py @@ -0,0 +1,52 @@ +""" from https://github.com/keithito/tacotron """ +from audiosr.latent_diffusion.modules.phoneme_encoder.text import cleaners +from audiosr.latent_diffusion.modules.phoneme_encoder.text.symbols import symbols + + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + +cleaner = getattr(cleaners, "english_cleaners2") + + +def text_to_sequence(text, cleaner_names): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [] + + clean_text = _clean_text(text, cleaner_names) + for symbol in clean_text: + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + return sequence + + +def cleaned_text_to_sequence(cleaned_text): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] + return sequence + + +def sequence_to_text(sequence): + """Converts a sequence of IDs back to a string""" + result = "" + for symbol_id in sequence: + s = _id_to_symbol[symbol_id] + result += s + return result + + +def _clean_text(text, cleaner_names): + text = cleaner(text) + return text diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/cleaners.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..18fd8193fb280f6e10688dc78d61486683f2e41b --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/cleaners.py @@ -0,0 +1,110 @@ +""" from https://github.com/keithito/tacotron """ + +""" +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). 
+""" + +import re +#from unidecode import unidecode +#from phonemizer import phonemize + + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + """Pipeline for English text, including abbreviation expansion.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + phonemes = phonemize(text, language="en-us", backend="espeak", strip=True) + phonemes = collapse_whitespace(phonemes) + return phonemes + + +def english_cleaners2(text): + """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + phonemes = phonemize( + text, + language="en-us", + backend="espeak", + strip=True, + preserve_punctuation=True, + with_stress=True, + ) + phonemes = collapse_whitespace(phonemes) + return phonemes diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/symbols.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..b419a4e6bcdbe663617c2edebff9100ab09baa6c --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/symbols.py @@ -0,0 +1,16 @@ +""" from https://github.com/keithito/tacotron """ + +""" +Defines the set of symbols used in text input to the model. 
+""" +_pad = "_" +_punctuation = ';:,.!?¡¿—…"«»“” ' +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" + + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +# Special symbol ids +SPACE_ID = symbols.index(" ") diff --git a/FlashSR/AudioSR/latent_diffusion/util.py b/FlashSR/AudioSR/latent_diffusion/util.py new file mode 100644 index 0000000000000000000000000000000000000000..47225418b8d35a6a010a4a44a7f8bfd013c0ac89 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/util.py @@ -0,0 +1,267 @@ +import importlib + +import torch +import numpy as np +from collections import abc + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + +CACHE = { + "get_vits_phoneme_ids": { + "PAD_LENGTH": 310, + "_pad": "_", + "_punctuation": ';:,.!?¡¿—…"«»“” ', + "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ", + "_special": "♪☎☒☝⚠", + } +} + +CACHE["get_vits_phoneme_ids"]["symbols"] = ( + [CACHE["get_vits_phoneme_ids"]["_pad"]] + + list(CACHE["get_vits_phoneme_ids"]["_punctuation"]) + + list(CACHE["get_vits_phoneme_ids"]["_letters"]) + + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"]) + + list(CACHE["get_vits_phoneme_ids"]["_special"]) +) +CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = { + s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"]) +} + + +def get_vits_phoneme_ids_no_padding(phonemes): + pad_token_id = 0 + pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"] + _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] + batchsize = len(phonemes) + + clean_text = phonemes[0] + "⚠" + sequence = [] + + for symbol in clean_text: + if symbol not in _symbol_to_id.keys(): + print("%s is not in the vocabulary. %s" % (symbol, clean_text)) + symbol = "_" + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + + def _pad_phonemes(phonemes_list): + return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list)) + + sequence = sequence[:pad_length] + + return { + "phoneme_idx": torch.LongTensor(_pad_phonemes(sequence)) + .unsqueeze(0) + .expand(batchsize, -1) + } + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join( + xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) + ) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. 
Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1.0, a_max=1.0) + return (x * 32767.0).astype(np.int16) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, + data, + n_proc, + target_data_type="ndarray", + cpu_intensive=True, + use_worker_id=False, +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
+ ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i : i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == "ndarray": + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == "list": + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/FlashSR/BigVGAN/LICENSE b/FlashSR/BigVGAN/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e9663595cc28938f88d6299acd3ba791542e4c0c --- /dev/null +++ b/FlashSR/BigVGAN/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 NVIDIA CORPORATION. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/FlashSR/BigVGAN/LibriTTS/dev-clean.txt b/FlashSR/BigVGAN/LibriTTS/dev-clean.txt new file mode 100644 index 0000000000000000000000000000000000000000..563b86e601c12604b511548fe50a21d57524e438 --- /dev/null +++ b/FlashSR/BigVGAN/LibriTTS/dev-clean.txt @@ -0,0 +1,115 @@ +dev-clean/1272/128104/1272_128104_000001_000000|A 'JOLLY' ART CRITIC +dev-clean/1272/141231/1272_141231_000007_000003|And when he attacked, it was always there to beat him aside. +dev-clean/1272/141231/1272_141231_000033_000002|If anything, he was pressing the attack. +dev-clean/1462/170138/1462_170138_000012_000002|Dear me, Mac, the girl couldn't possibly be better, you know." 
+dev-clean/1462/170142/1462_170142_000002_000005|Alexander did not sit down. +dev-clean/1462/170142/1462_170142_000029_000001|"I meant to, but somehow I couldn't. +dev-clean/1462/170142/1462_170142_000046_000004|The sight of you, Bartley, to see you living and happy and successful-can I never make you understand what that means to me?" She pressed his shoulders gently. +dev-clean/1462/170145/1462_170145_000012_000003|There is a letter for you there, in my desk drawer. +dev-clean/1462/170145/1462_170145_000033_000000|She felt the strength leap in the arms that held her so lightly. +dev-clean/1673/143397/1673_143397_000031_000007|He attempted to remove or intimidate the leaders by a common sentence, of acquittal or condemnation; he invested his representatives at Ephesus with ample power and military force; he summoned from either party eight chosen deputies to a free and candid conference in the neighborhood of the capital, far from the contagion of popular frenzy. +dev-clean/174/168635/174_168635_000040_000000|To teach Cosette to read, and to let her play, this constituted nearly the whole of Jean Valjean's existence. +dev-clean/174/50561/174_50561_000058_000001|They have the end of the game to themselves.) +dev-clean/174/84280/174_84280_000015_000000|And perhaps in this story I have said enough for you to understand why Mary has identified herself with something world-wide, has added to herself a symbolical value, and why it is I find in the whole crowded spectacle of mankind, a quality that is also hers, a sense of fine things entangled and stifled and unable to free themselves from the ancient limiting jealousies which law and custom embody. +dev-clean/1919/142785/1919_142785_000063_000000|[Illustration: SHALOT.] +dev-clean/1919/142785/1919_142785_000131_000001|Cut the bread into thin slices, place them in a cool oven overnight, and when thoroughly dry and crisp, roll them down into fine crumbs. +dev-clean/1988/147956/1988_147956_000016_000009|He was neatly dressed. +dev-clean/1988/148538/1988_148538_000015_000007|These persons then displayed towards each other precisely the same puerile jealousies which animate the men of democracies, the same eagerness to snatch the smallest advantages which their equals contested, and the same desire to parade ostentatiously those of which they were in possession. +dev-clean/1988/24833/1988_24833_000028_000003|He's taking the kid for a walk when a thunderstorm blows up. +dev-clean/1988/24833/1988_24833_000059_000000|"Doesn't pay enough?" Pop asks. +dev-clean/1993/147149/1993_147149_000051_000002|So leaving kind messages to George and Jane Wilson, and hesitating whether she might dare to send a few kind words to Jem, and deciding that she had better not, she stepped out into the bright morning light, so fresh a contrast to the darkened room where death had been. +dev-clean/1993/147965/1993_147965_000003_000004|I suppose, in the crowded clutter of their cave, the old man had come to believe that peace and order had vanished from the earth, or existed only in the old world he had left so far behind. +dev-clean/1993/147966/1993_147966_000020_000003|We found the chickens asleep; perhaps they thought night had come to stay. +dev-clean/2035/147960/2035_147960_000019_000001|He is all over Jimmy's boots. I scream for him to run, but he just hit and hit that snake like he was crazy." 
+dev-clean/2035/147961/2035_147961_000011_000002|He grew more and more excited, and kept pointing all around his bed, as if there were things there and he wanted mr Shimerda to see them. +dev-clean/2035/147961/2035_147961_000025_000002|Beside a frozen pond something happened to the other sledge; peter saw it plainly. +dev-clean/2035/152373/2035_152373_000010_000007|saint Aidan, the Apostle of Northumbria, had refused the late Egfrid's father absolution, on one occasion, until he solemnly promised to restore their freedom to certain captives of this description. +dev-clean/2086/149214/2086_149214_000005_000002|It is a legend prolonging itself, from an epoch now gray in the distance, down into our own broad daylight, and bringing along with it some of its legendary mist, which the reader, according to his pleasure, may either disregard, or allow it to float almost imperceptibly about the characters and events for the sake of a picturesque effect. +dev-clean/2086/149220/2086_149220_000016_000003|In short, I make pictures out of sunshine; and, not to be too much dazzled with my own trade, I have prevailed with Miss Hepzibah to let me lodge in one of these dusky gables. +dev-clean/2086/149220/2086_149220_000028_000000|Phoebe was on the point of retreating, but turned back, with some hesitation; for she did not exactly comprehend his manner, although, on better observation, its feature seemed rather to be lack of ceremony than any approach to offensive rudeness. +dev-clean/2277/149874/2277_149874_000007_000001|Her husband asked a few questions and sat down to read the evening paper. +dev-clean/2277/149896/2277_149896_000007_000006|He saw only her pretty face and neat figure and wondered why life was not arranged so that such joy as he found with her could be steadily maintained. +dev-clean/2277/149896/2277_149896_000025_000008|He jangled it fiercely several times in succession, but without avail. +dev-clean/2277/149897/2277_149897_000023_000000|"Well?" said Hurstwood. +dev-clean/2277/149897/2277_149897_000046_000002|He troubled over many little details and talked perfunctorily to everybody. +dev-clean/2412/153954/2412_153954_000004_000005|Even in middle age they were still comely, and the old grey haired women at their cottage doors had a dignity, not to say majesty, of their own. +dev-clean/2428/83699/2428_83699_000009_000000|Now it is a remarkable thing that I have always had an extraordinary predilection for the name Madge. +dev-clean/2428/83699/2428_83699_000024_000004|I had long been wishing that an old-fashioned Christmas had been completely extinct before I had thought of adventuring in quest of one. +dev-clean/2428/83699/2428_83699_000047_000000|"Perhaps you had better come inside." +dev-clean/2428/83705/2428_83705_000015_000004|I did not want any unpleasantness; and I am quite sure there would have been unpleasantness had I demurred. +dev-clean/2428/83705/2428_83705_000034_000002|"And what," inquired mrs Macpherson, "has Mary Ann given you?" +dev-clean/251/118436/251_118436_000017_000001|This man was clad in a brown camel hair robe and sandals, and a green turban was on his head. His expression was tranquil, his gaze impersonal. +dev-clean/251/136532/251_136532_000000_000003|Fitzgerald was still trying to find out how the germ had been transmitted. +dev-clean/251/136532/251_136532_000020_000004|Without question, he had become, overnight, the most widely known archaeologist in history. +dev-clean/251/137823/251_137823_000025_000001|Or grazed, at least," Tom added thankfully. 
+dev-clean/251/137823/251_137823_000054_000002|The two girls were as much upset as Tom's mother. +dev-clean/2803/154320/2803_154320_000017_000004|Think of Lady Glenarvan; think of Mary Grant!" +dev-clean/2803/154328/2803_154328_000028_000000|Wilson and Olbinett joined their companions, and all united to dig through the wall-john with his dagger, the others with stones taken from the ground, or with their nails, while Mulrady, stretched along the ground, watched the native guard through a crevice of the matting. +dev-clean/2803/154328/2803_154328_000080_000003|Where chance led them, but at any rate they were free. +dev-clean/2803/161169/2803_161169_000011_000019|What do you think of that from the coal tar. +dev-clean/2902/9008/2902_9008_000009_000001|He was a Greek, also, but of a more common, and, perhaps, lower type; dark and fiery, thin and graceful; his delicate figure and cheeks, wasted by meditation, harmonised well with the staid and simple philosophic cloak which he wore as a sign of his profession. +dev-clean/2902/9008/2902_9008_000048_000003|For aught I know or care, the plot may be an exactly opposite one, and the Christians intend to murder all the Jews. +dev-clean/3000/15664/3000_15664_000013_000004|These volcanic caves are not wanting in interest, and it is well to light a pitch pine torch and take a walk in these dark ways of the underworld whenever opportunity offers, if for no other reason to see with new appreciation on returning to the sunshine the beauties that lie so thick about us. +dev-clean/3000/15664/3000_15664_000029_000002|Thus the Shasta River issues from a large lake like spring in Shasta Valley, and about two thirds of the volume of the McCloud gushes forth in a grand spring on the east side of the mountain, a few miles back from its immediate base. +dev-clean/3170/137482/3170_137482_000010_000004|The nobility, the merchants, even workmen in good circumstances, are never seen in the 'magazzino', for cleanliness is not exactly worshipped in such places. +dev-clean/3170/137482/3170_137482_000037_000001|He was celebrated in Venice not only for his eloquence and his great talents as a statesman, but also for the gallantries of his youth. +dev-clean/3536/23268/3536_23268_000028_000000|"It is not the first time, I believe, you have acted contrary to that, Miss Milner," replied mrs Horton, and affected a tenderness of voice, to soften the harshness of her words. +dev-clean/3576/138058/3576_138058_000019_000003|He wondered to see the lance leaning against the tree, the shield on the ground, and Don Quixote in armour and dejected, with the saddest and most melancholy face that sadness itself could produce; and going up to him he said, "Be not so cast down, good man, for you have not fallen into the hands of any inhuman Busiris, but into Roque Guinart's, which are more merciful than cruel." +dev-clean/3752/4943/3752_4943_000026_000002|Lie quiet!" +dev-clean/3752/4943/3752_4943_000056_000002|His flogging wouldn't have killed a flea." +dev-clean/3752/4944/3752_4944_000031_000000|"Well now!" said Meekin, with asperity, "I don't agree with you. Everybody seems to be against that poor fellow-Captain Frere tried to make me think that his letters contained a hidden meaning, but I don't believe they did. +dev-clean/3752/4944/3752_4944_000063_000003|He'd rather kill himself." +dev-clean/3752/4944/3752_4944_000094_000000|"The Government may go to----, and you, too!" roared Burgess. +dev-clean/3853/163249/3853_163249_000058_000000|"I've done it, mother: tell me you're not sorry." 
+dev-clean/3853/163249/3853_163249_000125_000004|Help me to be brave and strong, David: don't let me complain or regret, but show me what lies beyond, and teach me to believe that simply doing the right is reward and happiness enough." +dev-clean/5338/24615/5338_24615_000004_000003|It had been built at a period when castles were no longer necessary, and when the Scottish architects had not yet acquired the art of designing a domestic residence. +dev-clean/5338/284437/5338_284437_000031_000001|A powerful ruler ought to be rich and to live in a splendid palace. +dev-clean/5536/43358/5536_43358_000012_000001|Being a natural man, the Indian was intensely poetical. +dev-clean/5536/43359/5536_43359_000015_000000|The family was not only the social unit, but also the unit of government. +dev-clean/5694/64025/5694_64025_000004_000006|Our regiment was the advance guard on Saturday evening, and did a little skirmishing; but General Gladden's brigade passed us and assumed a position in our immediate front. +dev-clean/5694/64029/5694_64029_000006_000005|I read it, and looked up to hand it back to him, when I discovered that he had a pistol cocked and leveled in my face, and says he, "Drop that gun; you are my prisoner." I saw there was no use in fooling about it. +dev-clean/5694/64029/5694_64029_000024_000002|The ground was literally covered with blue coats dead; and, if I remember correctly, there were eighty dead horses. +dev-clean/5694/64038/5694_64038_000015_000002|I could not imagine what had become of him. +dev-clean/5895/34615/5895_34615_000013_000003|Man can do nothing to create beauty, but everything to produce ugliness. +dev-clean/5895/34615/5895_34615_000025_000000|With this exception, Gwynplaine's laugh was everlasting. +dev-clean/5895/34622/5895_34622_000029_000002|In the opposite corner was the kitchen. +dev-clean/5895/34629/5895_34629_000021_000005|The sea is a wall; and if Voltaire-a thing which he very much regretted when it was too late-had not thrown a bridge over to Shakespeare, Shakespeare might still be in England, on the other side of the wall, a captive in insular glory. +dev-clean/6241/61943/6241_61943_000020_000000|My uncle came out of his cabin pale, haggard, thin, but full of enthusiasm, his eyes dilated with pleasure and satisfaction. +dev-clean/6241/61946/6241_61946_000014_000000|The rugged summits of the rocky hills were dimly visible on the edge of the horizon, through the misty fogs; every now and then some heavy flakes of snow showed conspicuous in the morning light, while certain lofty and pointed rocks were first lost in the grey low clouds, their summits clearly visible above, like jagged reefs rising from a troublous sea. +dev-clean/6241/61946/6241_61946_000051_000001|Then my uncle, myself, and guide, two boatmen and the four horses got into a very awkward flat bottom boat. +dev-clean/6295/64301/6295_64301_000010_000002|The music was broken, and Joseph left alone with the dumb instruments. +dev-clean/6313/66125/6313_66125_000020_000002|"Are you hurt?" +dev-clean/6313/66125/6313_66125_000053_000000|"Are you ready?" +dev-clean/6313/66129/6313_66129_000011_000001|"Cold water is the most nourishing thing we've touched since last night." +dev-clean/6313/66129/6313_66129_000045_000004|Of course, dogs can't follow the trail of an animal as well, now, as they could with snow on the ground. +dev-clean/6313/66129/6313_66129_000081_000000|Stacy dismounted and removed the hat carefully to one side. 
+dev-clean/6313/76958/6313_76958_000029_000000|Instantly there was a chorus of yells and snarls from the disturbed cowpunchers, accompanied by dire threats as to what they would do to the gopher did he ever disturb their rest in that way again. +dev-clean/6313/76958/6313_76958_000073_000001|"Those fellows have to go out. +dev-clean/6319/275224/6319_275224_000014_000001|And what is the matter with the beautiful straggling branches, that they are to be cut off as fast as they appear? +dev-clean/6319/57405/6319_57405_000019_000000|"It is rather a silly thing to do," said Deucalion; "and yet there can be no harm in it, and we shall see what will happen." +dev-clean/6319/64726/6319_64726_000017_000002|Then the prince took the princess by the hand; she was dressed in great splendour, but he did not hint that she looked as he had seen pictures of his great grandmother look; he thought her all the more charming for that. +dev-clean/6345/93302/6345_93302_000000_000001|All LibriVox recordings are in the public domain. +dev-clean/6345/93302/6345_93302_000049_000000|The fine tact of a noble woman seemed to have deserted her. +dev-clean/6345/93302/6345_93302_000073_000000|So she said- +dev-clean/6345/93306/6345_93306_000024_000002|What is it? +dev-clean/652/130737/652_130737_000031_000001|Good aroma. +dev-clean/7850/111771/7850_111771_000009_000001|After various flanking movements and costly assaults, the problem of taking Lee narrowed itself down to a siege of Petersburg. +dev-clean/7850/281318/7850_281318_000012_000000|She began to show them how to weave the bits of things together into nests, as they should be made. +dev-clean/7850/286674/7850_286674_000006_000001|You would think that, with six legs apiece and three joints in each leg, they might walk quite fast, yet they never did. +dev-clean/7850/73752/7850_73752_000006_000003|What a Neapolitan ball was his career then! +dev-clean/7976/105575/7976_105575_000009_000000|The burying party the next morning found nineteen dead Rebels lying together at one place. +dev-clean/7976/105575/7976_105575_000017_000000|Our regiment now pursued the flying Rebels with great vigor. +dev-clean/7976/110124/7976_110124_000021_000001|"We two are older and wiser than you are. It is for us to determine what shall be done. +dev-clean/7976/110124/7976_110124_000053_000002|The doors were strong and held securely. +dev-clean/7976/110523/7976_110523_000027_000000|"We will go in here," said Hansel, "and have a glorious feast. +dev-clean/8297/275154/8297_275154_000008_000000|Was this man-haggard, pallid, shabby, looking at him piteously with bloodshot eyes-the handsome, pleasant, prosperous brother whom he remembered? +dev-clean/8297/275154/8297_275154_000024_000011|Tell me where my wife is living now?" +dev-clean/8297/275155/8297_275155_000013_000006|What a perfect gentleman!" +dev-clean/8297/275155/8297_275155_000037_000000|"Say thoroughly worthy of the course forced upon me and my daughter by your brother's infamous conduct-and you will be nearer the mark!" +dev-clean/8297/275156/8297_275156_000013_000005|No more of it now. +dev-clean/84/121123/84_121123_000009_000000|But in less than five minutes the staircase groaned beneath an extraordinary weight. +dev-clean/84/121123/84_121123_000054_000000|It was something terrible to witness the silent agony, the mute despair of Noirtier, whose tears silently rolled down his cheeks. +dev-clean/84/121550/84_121550_000064_000000|And lo! 
a sudden lustre ran across On every side athwart the spacious forest, Such that it made me doubt if it were lightning. +dev-clean/84/121550/84_121550_000156_000000|Nor prayer for inspiration me availed, By means of which in dreams and otherwise I called him back, so little did he heed them. +dev-clean/84/121550/84_121550_000247_000000|Thus Beatrice; and I, who at the feet Of her commandments all devoted was, My mind and eyes directed where she willed. +dev-clean/8842/302203/8842_302203_000001_000001|And I remember that on the ninth day, being overcome with intolerable pain, a thought came into my mind concerning my lady: but when it had a little nourished this thought, my mind returned to its brooding over mine enfeebled body. diff --git a/FlashSR/BigVGAN/LibriTTS/dev-other.txt b/FlashSR/BigVGAN/LibriTTS/dev-other.txt new file mode 100644 index 0000000000000000000000000000000000000000..9952353bccd0d62f22da2c409ba6764df09bc005 --- /dev/null +++ b/FlashSR/BigVGAN/LibriTTS/dev-other.txt @@ -0,0 +1,93 @@ +dev-other/116/288045/116_288045_000003_000000|PART one +dev-other/116/288045/116_288045_000034_000001|He was only an idol. +dev-other/116/288047/116_288047_000002_000002|Observing the sun, the moon, and the stars overhead, the primitive man wished to account for them. +dev-other/116/288048/116_288048_000001_000000|Let me now give an idea of the method I propose to follow in the study of this subject. +dev-other/116/288048/116_288048_000020_000003|Leaving out Judas, and counting Matthias, who was elected in his place, we have thirteen apostles. +dev-other/1255/138279/1255_138279_000012_000000|"One. +dev-other/1255/138279/1255_138279_000049_000001|Will it be by banns or license?" +dev-other/1255/74899/1255_74899_000020_000000|"Pardon me. +dev-other/1255/90407/1255_90407_000006_000001|But, as the rain gave not the least sign of cessation, he observed: 'I think we shall have to go back.' +dev-other/1255/90407/1255_90407_000039_000002|Into it they plodded without pause, crossing the harbour bridge about midnight, wet to the skin. +dev-other/1255/90413/1255_90413_000023_000001|'Now what the devil this means I cannot tell,' he said to himself, reflecting stock still for a moment on the stairs. +dev-other/1585/131718/1585_131718_000025_000009|Edison marginalized documents extensively. +dev-other/1630/141772/1630_141772_000000_000002|Suddenly he again felt that he was alive and suffering from a burning, lacerating pain in his head. +dev-other/1630/141772/1630_141772_000039_000000|The quiet home life and peaceful happiness of Bald Hills presented itself to him. +dev-other/1630/73710/1630_73710_000019_000003|I almost wish papa would return, though I dread to see him. +dev-other/1630/96099/1630_96099_000033_000001|Why did you not follow him? +dev-other/1650/157641/1650_157641_000037_000001|mr w m +dev-other/1650/173551/1650_173551_000025_000000|Pierre went into that gloomy study which he had entered with such trepidation in his benefactor's lifetime. +dev-other/1651/136854/1651_136854_000046_000005|I have, however, this of gratitude, that I think of you with regard, when I do not, perhaps, give the proofs which I ought, of being, Sir, +dev-other/1686/142278/1686_142278_000015_000000|'No! not doubts as to religion; not the slightest injury to that.' He paused. +dev-other/1686/142278/1686_142278_000042_000001|Margaret was nearly upset again into a burst of crying. +dev-other/1701/141759/1701_141759_000001_000001|Not till midwinter was the count at last handed a letter addressed in his son's handwriting. 
+dev-other/1701/141759/1701_141759_000048_000000|"Why should you be ashamed?" +dev-other/1701/141760/1701_141760_000013_000003|"I only sent you the note yesterday by Bolkonski-an adjutant of Kutuzov's, who's a friend of mine. +dev-other/1701/141760/1701_141760_000056_000000|In spite of Prince Andrew's disagreeable, ironical tone, in spite of the contempt with which Rostov, from his fighting army point of view, regarded all these little adjutants on the staff of whom the newcomer was evidently one, Rostov felt confused, blushed, and became silent. +dev-other/2506/13150/2506_13150_000022_000000|--Nay-if you don't believe me, you may read the chapter for your pains. +dev-other/3660/172182/3660_172182_000012_000007|And a year, and a second, and a third, he proceeded thus, until his fame had flown over the face of the kingdom. +dev-other/3660/172183/3660_172183_000011_000000|So the maiden went forward, keeping in advance of Geraint, as he had desired her; and it grieved him as much as his wrath would permit, to see a maiden so illustrious as she having so much trouble with the care of the horses. +dev-other/3660/172183/3660_172183_000019_000040|Come with me to the court of a son in law of my sister, which is near here, and thou shalt have the best medical assistance in the kingdom." +dev-other/3660/6517/3660_6517_000036_000002|Bright sunshine. +dev-other/3660/6517/3660_6517_000059_000005|Not a single one has lost his good spirits. +dev-other/3663/172005/3663_172005_000022_000000|She must cross the Slide Brook valley, if possible, and gain the mountain opposite. +dev-other/3663/172528/3663_172528_000016_000008|He had been brought by my very dear friend Luca Martini, who passed the larger portion of the day with me. +dev-other/3915/57461/3915_57461_000018_000001|In a fit of madness I was tempted to kill and rob you. +dev-other/3915/98647/3915_98647_000018_000006|Thus the old custom is passing away. +dev-other/4323/13259/4323_13259_000009_000011|What would Jesus do? +dev-other/4323/13259/4323_13259_000020_000003|It seems she had been recently converted during the evangelist's meetings, and was killed while returning from one of the meetings in company with other converts and some of her friends. +dev-other/4323/18416/4323_18416_000019_000001|So she was asked to sing at musicales and receptions without end, until Alexia exclaimed at last, "They are all raving, stark mad over her, and it's all Polly's own fault, the whole of it." +dev-other/4323/18416/4323_18416_000050_000000|"I know, child; you think your old Grandpapa does just about right," said mr King soothingly, and highly gratified. +dev-other/4323/18416/4323_18416_000079_000002|"And I can't tolerate any thoughts I cannot speak." +dev-other/4323/55228/4323_55228_000028_000000|"Pete told you that I didn't care for any girl, only to paint?" demanded Bertram, angry and mystified. +dev-other/4323/55228/4323_55228_000071_000000|There was another silence. +dev-other/4570/102353/4570_102353_000001_000000|CHAPTER four. +dev-other/4570/14911/4570_14911_000009_000002|EYES-Brown, dark hazel or hazel, not deep set nor bulgy, and with a mild expression. +dev-other/4570/56594/4570_56594_000012_000000|"'No,' says the gentleman. +dev-other/4831/18525/4831_18525_000028_000000|"Oh! isn't it 'Oats, Peas, Beans, and Barley grow'?" cried Polly, as they watched them intently. +dev-other/4831/18525/4831_18525_000078_000001|"I want to write, too, I do," she cried, very much excited. +dev-other/4831/18525/4831_18525_000122_000000|"O dear me!" 
exclaimed Polly, softly, for she couldn't even yet get over that dreadful beginning. +dev-other/4831/25894/4831_25894_000022_000003|The other days were very much like this; sometimes they made more, sometimes less, but Tommo always 'went halves;' and Tessa kept on, in spite of cold and weariness, for her plans grew as her earnings increased, and now she hoped to get useful things, instead of candy and toys alone. +dev-other/4831/29134/4831_29134_000001_000000|The session was drawing toward its close. +dev-other/4831/29134/4831_29134_000018_000000|"So this poor little boy grew up to be a man, and had to go out in the world, far from home and friends to earn his living. +dev-other/5543/27761/5543_27761_000019_000000|Her mother went to hide. +dev-other/5543/27761/5543_27761_000065_000000|"Agathya says so, madam," answered Fedosya; "it's she that knows." +dev-other/5543/27761/5543_27761_000107_000000|"Sima, my dear, don't agitate yourself," said Sergey Modestovich in a whisper. +dev-other/5849/50873/5849_50873_000026_000000|"He has promised to do so." +dev-other/5849/50873/5849_50873_000074_000000|"The boy did it! +dev-other/5849/50962/5849_50962_000010_000000|"It's a schooner," said mr Bingham to mr Minturn, "and she has a very heavy cargo." +dev-other/5849/50963/5849_50963_000009_000003|Well, it was a long, slow job to drag those heavy logs around that point, and just when we were making headway, along comes a storm that drove the schooner and canoes out of business." +dev-other/5849/50964/5849_50964_000018_000001|There were the shells to be looked after, the fish nets, besides Downy, the duck, and Snoop, the cat. +dev-other/6123/59150/6123_59150_000016_000001|He kicked him two or three times with his heel in the face. +dev-other/6123/59186/6123_59186_000008_000000|"Catering care" is an appalling phrase. +dev-other/6267/53049/6267_53049_000007_000001|"I'd better be putting my grey matter into that algebra instead of wasting it plotting for a party dress that I certainly can't get. +dev-other/6267/53049/6267_53049_000045_000001|I am named after her." +dev-other/6267/65525/6267_65525_000018_000000|Dear mr Lincoln: +dev-other/6267/65525/6267_65525_000045_000006|You can't mistake it." +dev-other/6455/66379/6455_66379_000020_000002|(Deal, sir, if you please; better luck next time.)" +dev-other/6455/67803/6455_67803_000038_000000|"Yes," he answered. +dev-other/6467/56885/6467_56885_000012_000001|As you are so generously taking her on trust, may she never cause you a moment's regret. +dev-other/6467/97061/6467_97061_000010_000000|A terrible battle ensued, in which both kings performed prodigies of valour. +dev-other/6841/88291/6841_88291_000006_000006|One stood waiting for them to finish, a sheaf of long j h stamping irons in his hand. +dev-other/6841/88291/6841_88291_000019_000006|Cries arose in a confusion: "Marker" "Hot iron!" "Tally one!" Dust eddied and dissipated. +dev-other/6841/88294/6841_88294_000010_000003|Usually I didn't bother with his talk, for it didn't mean anything, but something in his voice made me turn. +dev-other/6841/88294/6841_88294_000048_000000|He stood there looking straight at me without winking or offering to move. +dev-other/700/122866/700_122866_000006_000003|You've been thirteen for a month, so I suppose it doesn't seem such a novelty to you as it does to me. +dev-other/700/122866/700_122866_000023_000006|Ruby Gillis is rather sentimental. +dev-other/700/122867/700_122867_000012_000004|My career is closed. 
+dev-other/700/122867/700_122867_000033_000003|At the end of the week Marilla said decidedly: +dev-other/700/122868/700_122868_000015_000003|mrs Lynde says that all play acting is abominably wicked." +dev-other/700/122868/700_122868_000038_000001|And Ruby is in hysterics-oh, Anne, how did you escape?" +dev-other/7601/101622/7601_101622_000018_000002|The very girls themselves set them on: +dev-other/7601/175351/7601_175351_000031_000008|Still, during the nights which followed the fifteenth of August, darkness was never profound; although the sun set, he still gave sufficient light by refraction. +dev-other/7641/96252/7641_96252_000003_000006|For these are careful only for themselves, for their own egoism, just like the bandit, from whom they are only distinguished by the absurdity of their means. +dev-other/7641/96670/7641_96670_000013_000001|The mist lifted suddenly and she saw three strangers in the palace courtyard. +dev-other/7641/96684/7641_96684_000009_000000|"What years of happiness have been mine, O Apollo, through your friendship for me," said Admetus. +dev-other/7641/96684/7641_96684_000031_000002|How noble it was of Admetus to bring him into his house and give entertainment to him while such sorrow was upon him. +dev-other/7697/105815/7697_105815_000048_000002|And they brought out the jaw bone of an ass with which Samson did such great feats, and the sling and stone with which David slew Goliath of Gath. +dev-other/8173/294714/8173_294714_000006_000001|"Don't spoil my pleasure in seeing you again by speaking of what can never be! Have you still to be told how it is that you find me here alone with my child?" +dev-other/8173/294714/8173_294714_000027_000001|What was there to prevent her from insuring her life, if she pleased, and from so disposing of the insurance as to give Van Brandt a direct interest in her death? +dev-other/8254/115543/8254_115543_000034_000000|"Yes, and how he orders every one about him. +dev-other/8254/84205/8254_84205_000029_000000|"I'm not afraid of them hitting me, my lad," said Griggs confidently. "Being shot at by fellows with bows and arrows sounds bad enough, but there's not much risk here." +dev-other/8254/84205/8254_84205_000073_000000|"Right; I do, neighbour, and it's very handsome of you to offer me the chance to back out. +dev-other/8288/274162/8288_274162_000023_000000|"Exactly. +dev-other/8288/274162/8288_274162_000078_000000|"So much the worse. diff --git a/FlashSR/BigVGAN/LibriTTS/test-clean.txt b/FlashSR/BigVGAN/LibriTTS/test-clean.txt new file mode 100644 index 0000000000000000000000000000000000000000..5bbfab4da31088b3f42dcf2e15238e046d1781dc --- /dev/null +++ b/FlashSR/BigVGAN/LibriTTS/test-clean.txt @@ -0,0 +1,97 @@ +test-clean/1089/134686/1089_134686_000001_000001|He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce. Stuff it into you, his belly counselled him. +test-clean/1089/134686/1089_134686_000020_000001|We can scut the whole hour. +test-clean/1089/134691/1089_134691_000004_000001|Yet her mistrust pricked him more keenly than his father's pride and he thought coldly how he had watched the faith which was fading down in his soul ageing and strengthening in her eyes. +test-clean/1089/134691/1089_134691_000027_000004|Now, at the name of the fabulous artificer, he seemed to hear the noise of dim waves and to see a winged form flying above the waves and slowly climbing the air. 
+test-clean/1188/133604/1188_133604_000018_000002|There are just four touches-fine as the finest penmanship-to do that beak; and yet you will find that in the peculiar paroquettish mumbling and nibbling action of it, and all the character in which this nibbling beak differs from the tearing beak of the eagle, it is impossible to go farther or be more precise. +test-clean/121/121726/121_121726_000046_000003|Tied to a woman. +test-clean/121/127105/121_127105_000024_000000|He laughed for the first time. +test-clean/1284/1180/1284_1180_000001_000000|The Crooked Magician +test-clean/1284/1181/1284_1181_000005_000000|The head of the Patchwork Girl was the most curious part of her. +test-clean/1320/122612/1320_122612_000019_000005|It is true that the horses are here, but the Hurons are gone; let us, then, hunt for the path by which they parted." +test-clean/1320/122612/1320_122612_000056_000002|Then he reappeared, creeping along the earth, from which his dress was hardly distinguishable, directly in the rear of his intended captive. +test-clean/1580/141083/1580_141083_000012_000000|"The first page on the floor, the second in the window, the third where you left it," said he. +test-clean/1580/141083/1580_141083_000041_000003|Above were three students, one on each story. +test-clean/1580/141083/1580_141083_000063_000001|Holmes held it out on his open palm in the glare of the electric light. +test-clean/1580/141083/1580_141083_000110_000001|Where were you when you began to feel bad?" +test-clean/1580/141084/1580_141084_000024_000002|Pencils, too, and knives-all was satisfactory. +test-clean/1580/141084/1580_141084_000060_000001|"I frankly admit that I am unable to prove it. +test-clean/1580/141084/1580_141084_000085_000000|"Good heavens! have you nothing to add?" cried Soames. +test-clean/1995/1826/1995_1826_000022_000001|Miss Taylor did not know much about cotton, but at least one more remark seemed called for. +test-clean/1995/1836/1995_1836_000016_000001|No, of course there was no immediate danger; but when people were suddenly thrust beyond their natural station, filled with wild ideas and impossible ambitions, it meant terrible danger to Southern white women. +test-clean/1995/1837/1995_1837_000024_000000|He heard that she was down stairs and ran to meet her with beating heart. +test-clean/2300/131720/2300_131720_000016_000005|Having travelled around the world, I had cultivated an indifference to any special difficulties of that kind. +test-clean/2300/131720/2300_131720_000030_000005|I telephoned again, and felt something would happen, but fortunately it did not. +test-clean/237/126133/237_126133_000002_000004|It got to be noticed finally; and one and all redoubled their exertions to make everything twice as pleasant as ever! +test-clean/237/126133/237_126133_000049_000000|But the chubby face didn't look up brightly, as usual: and the next moment, without a bit of warning, Phronsie sprang past them all, even Polly, and flung herself into mr King's arms, in a perfect torrent of sobs. +test-clean/237/134493/237_134493_000008_000003|Alexandra lets you sleep late. +test-clean/237/134500/237_134500_000001_000001|Frank sat up until a late hour reading the Sunday newspapers. +test-clean/237/134500/237_134500_000014_000000|"I don't know all of them, but I know lindens are. +test-clean/237/134500/237_134500_000034_000000|She sighed despondently. +test-clean/260/123286/260_123286_000019_000002|Therefore don't talk to me about views and prospects." 
+test-clean/260/123286/260_123286_000049_000005|He shakes his head negatively. +test-clean/260/123288/260_123288_000016_000002|It rushes on from the farthest recesses of the vast cavern. +test-clean/260/123288/260_123288_000043_000001|I could just see my uncle at full length on the raft, and Hans still at his helm and spitting fire under the action of the electricity which has saturated him. +test-clean/2830/3979/2830_3979_000007_000000|PREFACE +test-clean/2830/3980/2830_3980_000018_000001|Humble man that he was, he will not now take a back seat. +test-clean/2961/961/2961_961_000004_000037|Then your city did bravely, and won renown over the whole earth. +test-clean/2961/961/2961_961_000023_000003|But violent as were the internal and alimentary fluids, the tide became still more violent when the body came into contact with flaming fire, or the solid earth, or gliding waters, or the stormy wind; the motions produced by these impulses pass through the body to the soul and have the name of sensations. +test-clean/3570/5694/3570_5694_000009_000003|The canon of reputability is at hand and seizes upon such innovations as are, according to its standard, fit to survive. +test-clean/3570/5695/3570_5695_000001_000003|But the middle class wife still carries on the business of vicarious leisure, for the good name of the household and its master. +test-clean/3570/5695/3570_5695_000009_000005|Considered by itself simply-taken in the first degree-this added provocation to which the artisan and the urban laboring classes are exposed may not very seriously decrease the amount of savings; but in its cumulative action, through raising the standard of decent expenditure, its deterrent effect on the tendency to save cannot but be very great. +test-clean/3570/5696/3570_5696_000011_000006|For this is the basis of award of the instinct of workmanship, and that instinct is the court of final appeal in any question of economic truth or adequacy. +test-clean/3729/6852/3729_6852_000004_000003|In order to please her, I spoke to her of the Abbe Conti, and I had occasion to quote two lines of that profound writer. +test-clean/4077/13754/4077_13754_000002_000000|The troops, once in Utah, had to be provisioned; and everything the settlers could spare was eagerly bought at an unusual price. The gold changed hands. +test-clean/4446/2271/4446_2271_000003_000004|There's everything in seeing Hilda while she's fresh in a part. +test-clean/4446/2271/4446_2271_000020_000001|Lady Westmere is very fond of Hilda." +test-clean/4446/2273/4446_2273_000008_000002|I've no need for fine clothes in Mac's play this time, so I can afford a few duddies for myself. +test-clean/4446/2273/4446_2273_000027_000004|She did my blouses beautifully the last time I was there, and was so delighted to see me again. +test-clean/4446/2273/4446_2273_000046_000001|"Aren't you afraid to let the wind low like that on your neck? +test-clean/4446/2275/4446_2275_000013_000000|Hilda was pale by this time, and her eyes were wide with fright. +test-clean/4446/2275/4446_2275_000038_000006|"You want to tell me that you can only see me like this, as old friends do, or out in the world among people? +test-clean/4507/16021/4507_16021_000011_000000|It engenders a whole world, la pegre, for which read theft, and a hell, la pegrenne, for which read hunger. +test-clean/4507/16021/4507_16021_000030_000001|Facts form one of these, and ideas the other. 
+test-clean/4970/29093/4970_29093_000010_000000|Delightful illusion of paint and tinsel and silk attire, of cheap sentiment and high and mighty dialogue! +test-clean/4970/29093/4970_29093_000047_000000|"Never mind the map. +test-clean/4970/29095/4970_29095_000021_000000|"I will practice it." +test-clean/4970/29095/4970_29095_000055_000002|He took it with him from the Southern Hotel, when he went to walk, and read it over and again in an unfrequented street as he stumbled along. +test-clean/4992/41797/4992_41797_000014_000002|He keeps the thou shalt not commandments first rate, Hen Lord does! +test-clean/4992/41806/4992_41806_000020_000001|Thou who settest the solitary in families, bless the life that is sheltered here. +test-clean/5105/28241/5105_28241_000004_000004|The late astounding events, however, had rendered Procope manifestly uneasy, and not the less so from his consciousness that the count secretly partook of his own anxiety. +test-clean/5142/33396/5142_33396_000004_000004|At the prow I carved the head with open mouth and forked tongue thrust out. +test-clean/5142/33396/5142_33396_000039_000000|"The thralls were bringing in a great pot of meat. +test-clean/5142/36377/5142_36377_000013_000003|I liked Naomi Colebrook at first sight; liked her pleasant smile; liked her hearty shake of the hand when we were presented to each other. +test-clean/5639/40744/5639_40744_000003_000006|Mother! dear father! do you hear me? +test-clean/5639/40744/5639_40744_000022_000000|Just then Leocadia came to herself, and embracing the cross seemed changed into a sea of tears, and the gentleman remained in utter bewilderment, until his wife had repeated to him, from beginning to end, Leocadia's whole story; and he believed it, through the blessed dispensation of Heaven, which had confirmed it by so many convincing testimonies. +test-clean/5683/32865/5683_32865_000018_000000|Well, it was pretty-French, I dare say-a little set of tablets-a toy-the cover of enamel, studded in small jewels, with a slender border of symbolic flowers, and with a heart in the centre, a mosaic of little carbuncles, rubies, and other red and crimson stones, placed with a view to light and shade. +test-clean/5683/32866/5683_32866_000005_000000|'Did you see that?' said Wylder in my ear, with a chuckle; and, wagging his head, he added, rather loftily for him, 'Miss Brandon, I reckon, has taken your measure, Master Stanley, as well as i I wonder what the deuce the old dowager sees in him. +test-clean/5683/32866/5683_32866_000047_000002|I was not a bit afraid of being found out. +test-clean/5683/32879/5683_32879_000036_000002|Be he near, or be he far, I regard his very name with horror.' +test-clean/6829/68769/6829_68769_000011_000000|So as soon as breakfast was over the next morning Beth and Kenneth took one of the automobiles, the boy consenting unwillingly to this sort of locomotion because it would save much time. +test-clean/6829/68769/6829_68769_000051_000001|One morning she tried to light the fire with kerosene, and lost her sight. +test-clean/6829/68769/6829_68769_000089_000001|Why should you do all this?" +test-clean/6829/68771/6829_68771_000018_000003|A speakers' stand, profusely decorated, had been erected on the lawn, and hundreds of folding chairs provided for seats. +test-clean/6930/75918/6930_75918_000000_000001|Night. +test-clean/6930/81414/6930_81414_000041_000001|Here is his scarf, which has evidently been strained, and on it are spots of blood, while all around are marks indicating a struggle. 
+test-clean/7021/79740/7021_79740_000010_000006|I observe that, when you both wish for the same thing, you don't quarrel for it and try to pull it away from one another; but one waits like a lady until the other has done with it. +test-clean/7021/85628/7021_85628_000017_000000|"I am going to the court ball," answered Anders. +test-clean/7127/75946/7127_75946_000022_000002|It is necessary, therefore, that he should comply." +test-clean/7127/75946/7127_75946_000061_000001|Disdainful of a success of which Madame showed no acknowledgement, he thought of nothing but boldly regaining the marked preference of the princess. +test-clean/7127/75947/7127_75947_000035_000000|"Quite true, and I believe you are right. +test-clean/7176/88083/7176_88083_000002_000003|He was too imposing in appearance, too gorgeous in apparel, too bold and vigilant in demeanor to be so misunderstood. +test-clean/7176/88083/7176_88083_000017_000000|Immediately over his outstretched gleaming head flew the hawk. +test-clean/7176/92135/7176_92135_000011_000000|And, so on in the same vein for some thirty lines. +test-clean/7176/92135/7176_92135_000074_000001|Tea, please, Matthews. +test-clean/7729/102255/7729_102255_000011_000003|The Free State Hotel served as barracks. +test-clean/7729/102255/7729_102255_000028_000009|They were squads of Kansas militia, companies of "peaceful emigrants," or gangs of irresponsible outlaws, to suit the chance, the whim, or the need of the moment. +test-clean/8230/279154/8230_279154_000003_000002|In the present lecture I shall attempt the analysis of memory knowledge, both as an introduction to the problem of knowledge in general, and because memory, in some form, is presupposed in almost all other knowledge. +test-clean/8230/279154/8230_279154_000013_000003|One of these is context. +test-clean/8230/279154/8230_279154_000027_000000|A further stage is RECOGNITION. +test-clean/8455/210777/8455_210777_000022_000003|And immediately on his sitting down, there got up a gentleman to whom I had not been introduced before this day, and gave the health of Mrs Neverbend and the ladies of Britannula. +test-clean/8455/210777/8455_210777_000064_000001|Government that he shall be treated with all respect, and that those honours shall be paid to him which are due to the President of a friendly republic. +test-clean/8463/287645/8463_287645_000023_000001|For instance, Jacob Taylor was noticed on the record book as being twenty three years of age, and the name of his master was entered as "William Pollit;" but as Jacob had never been allowed to learn to read, he might have failed in giving a correct pronunciation of the name. +test-clean/8463/294825/8463_294825_000048_000000|- CENTIMETER Roughly two fifths of an inch +test-clean/8463/294828/8463_294828_000046_000001|Conseil did them in a flash, and I was sure the lad hadn't missed a thing, because he classified shirts and suits as expertly as birds and mammals. +test-clean/8555/284447/8555_284447_000018_000002|The poor Queen, by the way, was seldom seen, as she passed all her time playing solitaire with a deck that was one card short, hoping that before she had lived her entire six hundred years she would win the game. +test-clean/8555/284447/8555_284447_000049_000000|Now, indeed, the Boolooroo was as angry as he was amazed. +test-clean/8555/284449/8555_284449_000039_000000|When the courtiers and the people assembled saw the goat they gave a great cheer, for the beast had helped to dethrone their wicked Ruler. 
+test-clean/8555/292519/8555_292519_000041_000000|She was alone that night. He had broken into her courtyard. Above the gurgling gutters he heard- surely- a door unchained? diff --git a/FlashSR/BigVGAN/LibriTTS/test-other.txt b/FlashSR/BigVGAN/LibriTTS/test-other.txt new file mode 100644 index 0000000000000000000000000000000000000000..930b528eb7a3e92597cd1d5ad3d2519143ad2051 --- /dev/null +++ b/FlashSR/BigVGAN/LibriTTS/test-other.txt @@ -0,0 +1,103 @@ +test-other/1688/142285/1688_142285_000000_000000|'Margaret!' said mr Hale, as he returned from showing his guest downstairs; 'I could not help watching your face with some anxiety, when mr Thornton made his confession of having been a shop boy. +test-other/1688/142285/1688_142285_000046_000000|'No, mamma; that Anne Buckley would never have done.' +test-other/1998/15444/1998_15444_000012_000000|Simple filtration will sometimes suffice to separate the required substance; in other cases dialysis will be necessary, in order that crystalloid substances may be separated from colloid bodies. +test-other/1998/29454/1998_29454_000021_000001|Fried eggs and bacon-he had one egg and the man had three-bread and butter-and if the bread was thick, so was the butter-and as many cups of tea as you liked to say thank you for. +test-other/1998/29454/1998_29454_000053_000001|It almost looked, Dickie thought, as though he had brought them out for some special purpose. +test-other/1998/29455/1998_29455_000022_000000|It was a wonderful day. +test-other/1998/29455/1998_29455_000082_000003|But 'e's never let it out." +test-other/2414/128292/2414_128292_000003_000000|"What!" said he, "have not the most ludicrous things always happened to us old anchorites and saints? +test-other/2609/156975/2609_156975_000036_000004|The cruel fate of his people and the painful experience in Egypt that had driven him into the wilderness prepared his mind to receive this training. +test-other/3005/163389/3005_163389_000017_000001|And they laughed all the time, and that made the duke mad; and everybody left, anyway, before the show was over, but one boy which was asleep. +test-other/3005/163390/3005_163390_000023_000021|S'pose people left money laying around where he was what did he do? +test-other/3005/163391/3005_163391_000021_000000|"It's a pretty long journey. +test-other/3005/163399/3005_163399_000013_000002|When we got there she set me down in a split bottomed chair, and set herself down on a little low stool in front of me, holding both of my hands, and says: +test-other/3005/163399/3005_163399_000045_000000|He sprung to the window at the head of the bed, and that give mrs Phelps the chance she wanted. +test-other/3080/5040/3080_5040_000000_000010|You have no such ladies in Ireland? +test-other/3331/159605/3331_159605_000006_000002|I could do so much for all at home how I should enjoy that!" And Polly let her thoughts revel in the luxurious future her fancy painted. +test-other/3331/159605/3331_159605_000082_000000|"Who got up that nice idea, I should like to know?" demanded Polly, as Fanny stopped for breath. +test-other/3528/168656/3528_168656_000003_000003|She told wonders of the Abbey of Fontevrault,--that it was like a city, and that there were streets in the monastery. +test-other/3528/168669/3528_168669_000030_000000|A silence ensued. +test-other/3528/168669/3528_168669_000075_000000|"Like yourself, reverend Mother." +test-other/3528/168669/3528_168669_000123_000000|"But the commissary of police-" +test-other/3528/168669/3528_168669_000137_000000|"That is well." 
+test-other/3528/168669/3528_168669_000164_000008|I shall have my lever. +test-other/3538/142836/3538_142836_000021_000003|However, as late as the reigns of our two last Georges, fabulous sums were often expended upon fanciful desserts. +test-other/3538/163619/3538_163619_000054_000000|'Now he says that you are to make haste and throw yourself overboard,' answered the step mother. +test-other/3538/163622/3538_163622_000069_000000|So they travelled onwards again, for many and many a mile, over hill and dale. +test-other/3538/163624/3538_163624_000038_000000|Then Sigurd went down into that deep place, and dug many pits in it, and in one of the pits he lay hidden with his sword drawn. +test-other/367/130732/367_130732_000002_000001|Probably nowhere in San Francisco could one get lobster better served than in the Old Delmonico restaurant of the days before the fire. +test-other/3764/168670/3764_168670_000003_000000|"But you, Father Madeleine?" +test-other/3764/168670/3764_168670_000043_000000|"Yes." +test-other/3764/168670/3764_168670_000083_000005|He grumbled:-- +test-other/3764/168671/3764_168671_000012_000003|He did what he liked with him. +test-other/3764/168671/3764_168671_000046_000000|"Comrade!" cried Fauchelevent. +test-other/3997/180294/3997_180294_000023_000000|Then, when God allows love to a courtesan, that love, which at first seems like a pardon, becomes for her almost without penitence. +test-other/3997/180294/3997_180294_000065_000001|The count will be coming back, and there is nothing to be gained by his finding you here." +test-other/3997/180297/3997_180297_000034_000004|For these people we have to be merry when they are merry, well when they want to sup, sceptics like themselves. +test-other/3997/182399/3997_182399_000014_000003|Oh, my, no! +test-other/4198/61336/4198_61336_000000_000003|It is significant to note in this connection that the new king was an unswerving adherent of the cult of Ashur, by the adherents of which he was probably strongly supported. +test-other/4198/61336/4198_61336_000033_000001|Nabonassar had died and was succeeded by his son Nabu nadin zeri, who, after reigning for two years, was slain in a rebellion. +test-other/4294/14317/4294_14317_000022_000011|I do not condescend to smite you. He looked at me submissively and said nothing. +test-other/4294/35475/4294_35475_000018_000001|At last they reached a wide chasm that bounded the Ogre's domain. +test-other/4294/35475/4294_35475_000050_000002|They said, "We are only waiting to lay some wily plan to capture the Ogre." +test-other/4294/9934/4294_9934_000025_000000|"Gold; here it is." +test-other/4350/10919/4350_10919_000006_000000|"Immediately, princess. +test-other/4350/9170/4350_9170_000005_000001|Authority, in the sense in which the word is ordinarily understood, is a means of forcing a man to act in opposition to his desires. +test-other/4350/9170/4350_9170_000056_000000|But the fatal significance of universal military service, as the manifestation of the contradiction inherent in the social conception of life, is not only apparent in that. +test-other/4852/28311/4852_28311_000031_000001|After a step or two, not finding his friend beside him, he turned. +test-other/4852/28319/4852_28319_000013_000002|mr Wicker waited patiently beside him for a few moments for Chris to get up his courage. +test-other/533/1066/533_1066_000008_000000|"I mean," he persisted, "do you feel as though you could go through with something rather unusual?" 
+test-other/533/131562/533_131562_000018_000000|mr Huntingdon then went up stairs. +test-other/5442/41168/5442_41168_000002_000001|Sergey Ivanovitch, waiting till the malignant gentleman had finished speaking, said that he thought the best solution would be to refer to the act itself, and asked the secretary to find the act. +test-other/5442/41169/5442_41169_000003_000000|"He's such a blackguard! +test-other/5442/41169/5442_41169_000030_000000|"And with what he made he'd increase his stock, or buy some land for a trifle, and let it out in lots to the peasants," Levin added, smiling. He had evidently more than once come across those commercial calculations. +test-other/5484/24317/5484_24317_000040_000006|Let us hope that you will make this three leaved clover the luck promising four leaved one. +test-other/5484/24318/5484_24318_000015_000002|The blood of these innocent men would be on his head if he did not listen to her representations. +test-other/5484/24318/5484_24318_000068_000001|He was appearing before his companions only to give truth its just due. +test-other/5764/299665/5764_299665_000041_000004|He saw the seeds that man had planted wither and perish, but he sent no rain. +test-other/5764/299665/5764_299665_000070_000000|Think of the egotism of a man who believes that an infinite being wants his praise! +test-other/5764/299665/5764_299665_000102_000000|The first stone is that matter-substance-cannot be destroyed, cannot be annihilated. +test-other/5764/299665/5764_299665_000134_000000|You cannot reform these people with tracts and talk. +test-other/6070/63485/6070_63485_000025_000003|Hand me the cash, and I will hand you the pocketbook." +test-other/6070/86744/6070_86744_000027_000000|"Have you bachelor's apartments there? +test-other/6070/86745/6070_86745_000001_000002|Two windows only of the pavilion faced the street; three other windows looked into the court, and two at the back into the garden. +test-other/6128/63240/6128_63240_000012_000002|Neither five nor fifteen, and yet not ten exactly, but either nine or eleven. +test-other/6128/63240/6128_63240_000042_000002|mrs Luna explained to her sister that her freedom of speech was caused by his being a relation-though, indeed, he didn't seem to know much about them. +test-other/6128/63244/6128_63244_000002_000000|"I can't talk to those people, I can't!" said Olive Chancellor, with a face which seemed to plead for a remission of responsibility. +test-other/6432/63722/6432_63722_000026_000000|"Not the least in the world-not as much as you do," was the cool answer. +test-other/6432/63722/6432_63722_000050_000004|Queen Elizabeth was very fond of watches and clocks, and her friends, knowing that, used to present her with beautiful specimens. Some of the watches of her day were made in the form of crosses, purses, little books, and even skulls." +test-other/6432/63722/6432_63722_000080_000003|When it does it will create a sensation." +test-other/6432/63723/6432_63723_000026_000000|"No; but he will, or I'll sue him and get judgment. +test-other/6432/63723/6432_63723_000057_000000|"Then for the love of-" +test-other/6432/63723/6432_63723_000080_000000|"Hello, Harry! +test-other/6938/70848/6938_70848_000046_000003|Show me the source!" +test-other/6938/70848/6938_70848_000104_000000|With biting sarcasm he went on to speak of the Allied diplomats, till then contemptuous of Russia's invitation to an armistice, which had been accepted by the Central Powers. +test-other/7105/2330/7105_2330_000021_000000|"He won't go unless he has a brass band. 
+test-other/7105/2340/7105_2340_000015_000001|We feel that we must live on cream for the rest of our lives. +test-other/7902/96591/7902_96591_000008_000001|I did not come to frighten you; you frightened me." +test-other/7902/96591/7902_96591_000048_000000|"No," he thought to himself, "I don't believe they would kill me, but they would knock me about." +test-other/7902/96592/7902_96592_000024_000001|Once out of that room he could ran, and by daylight the smugglers dare not hunt him down. +test-other/7902/96592/7902_96592_000063_000000|"What for?" cried Ram. +test-other/7902/96594/7902_96594_000014_000001|These fellows are very cunning, but we shall be too many for them one of these days." +test-other/7902/96594/7902_96594_000062_000001|Keep a sharp look out on the cliff to see if Mr Raystoke is making signals for a boat. +test-other/7902/96595/7902_96595_000039_000000|The man shook his head, and stared as if he didn't half understand the drift of what was said. +test-other/7975/280057/7975_280057_000009_000000|Naturally we were Southerners in sympathy and in fact. +test-other/7975/280057/7975_280057_000025_000004|On reaching the camp the first person I saw whom I knew was Cole Younger. +test-other/7975/280076/7975_280076_000013_000001|I will give you this outline and sketch of my whereabouts and actions at the time of certain robberies with which I am charged. +test-other/7975/280084/7975_280084_000007_000000|But between the time we broke camp and the time they reached the bridge the three who went ahead drank a quart of whisky, and there was the initial blunder at Northfield. +test-other/7975/280085/7975_280085_000005_000002|Some of the boys wanted to kill him, on the theory that "dead men tell no tales," while others urged binding him and leaving him in the woods. +test-other/8131/117016/8131_117016_000005_000000|The Stonewall gang numbered perhaps five hundred. +test-other/8131/117016/8131_117016_000025_000001|"And don't let them get away!" +test-other/8131/117016/8131_117016_000047_000006|I can always go back to Earth, and I'll try to take you along. +test-other/8131/117017/8131_117017_000005_000000|Gordon hit the signal switch, and the Marspeaker let out a shrill whistle. +test-other/8131/117017/8131_117017_000020_000003|There's no graft out here." +test-other/8131/117029/8131_117029_000007_000002|Wrecks were being broken up, with salvageable material used for newer homes. Gordon came to a row of temporary bubbles, individual dwellings built like the dome, but opaque for privacy. +test-other/8131/117029/8131_117029_000023_000004|But there'll be pushers as long as weak men turn to drugs, and graft as long as voters allow the thing to get out of their hands. +test-other/8188/269288/8188_269288_000018_000000|A few moments later there came a tap at the door. +test-other/8188/269288/8188_269288_000053_000001|"Do you want to kill me? +test-other/8188/269290/8188_269290_000035_000001|"But now, Leslie, what is the trouble? +test-other/8188/269290/8188_269290_000065_000000|"I don't think she is quite well," replied Leslie. +test-other/8280/266249/8280_266249_000030_000000|The ladies were weary, and retired to their state rooms shortly after tea, but the gentlemen sought the open air again and paced the deck for some time. +test-other/8280/266249/8280_266249_000113_000000|It was the last game of cards for that trip. +test-other/8461/278226/8461_278226_000026_000000|Laura thanked the French artist and then took her husband's arm and walked away with him. 
+test-other/8461/281231/8461_281231_000029_000002|Before long the towering flames had surmounted every obstruction, and rose to the evening skies one huge and burning beacon, seen far and wide through the adjacent country; tower after tower crashed down, with blazing roof and rafter. diff --git a/FlashSR/BigVGAN/LibriTTS/val-full.txt b/FlashSR/BigVGAN/LibriTTS/val-full.txt new file mode 100644 index 0000000000000000000000000000000000000000..e06cd55f99d1b7835af616420d8f026eab221134 --- /dev/null +++ b/FlashSR/BigVGAN/LibriTTS/val-full.txt @@ -0,0 +1,119 @@ +train-clean-100/103/1241/103_1241_000000_000001|matthew Cuthbert is surprised +train-clean-100/1594/135914/1594_135914_000033_000001|He told them, that having taken refuge in a small village, he there fell sick; that some charitable peasants had taken care of him, but finding he did not recover, a camel driver had undertaken to carry him to the hospital at Bagdad. +train-clean-100/233/155990/233_155990_000018_000002|I did, however, receive aid from the Emperor of Germany. +train-clean-100/3240/131231/3240_131231_000041_000003|Some persons, thinking them to be sea fishes, placed them in salt water, according to mr Roberts. +train-clean-100/40/222/40_222_000026_000000|"No, read it yourself," cried Catherine, whose second thoughts were clearer. +train-clean-100/4406/16882/4406_16882_000014_000002|Then they set me upon a horse with my wounded child in my lap, and there being no furniture upon the horse's back, as we were going down a steep hill we both fell over the horse's head, at which they, like inhumane creatures, laughed, and rejoiced to see it, though I thought we should there have ended our days, as overcome with so many difficulties. +train-clean-100/5393/19218/5393_19218_000115_000000|"Where is it going then?" +train-clean-100/6147/34606/6147_34606_000013_000008|One was "a dancing master;" that is to say he made the rustics frisk about by pricking the calves of their legs with the point of his sword. +train-clean-100/6848/76049/6848_76049_000003_000007|But suppose she was not all ordinary female person.... +train-clean-100/7505/258964/7505_258964_000026_000007|During the Boer War horses and mules rose in price in the United States on account of British purchases. +train-clean-100/831/130739/831_130739_000015_000000|But enough of these revelations. +train-clean-100/887/123291/887_123291_000028_000000|Here the Professor laid hold of the fossil skeleton, and handled it with the skill of a dexterous showman. +train-clean-360/112/123216/112_123216_000035_000009|The wonderful day had come and Roy's violets had no place in it. +train-clean-360/1323/149236/1323_149236_000007_000004|It was vain to hope that mere words would quiet a nation which had not, in any age, been very amenable to control, and which was now agitated by hopes and resentments, such as great revolutions, following great oppressions, naturally engender. +train-clean-360/1463/134465/1463_134465_000058_000000|Both Sandy and I began to laugh. +train-clean-360/1748/1562/1748_1562_000067_000000|"Oh, Pocket, Pocket," said I; but by this time the party which had gone towards the house, rushed out again, shouting and screaming with laughter. +train-clean-360/1914/133440/1914_133440_000014_000001|With the last twenty or thirty feet of it a deadly nausea came upon me. +train-clean-360/207/143321/207_143321_000070_000002|The canoes were not on the river bank. +train-clean-360/2272/152267/2272_152267_000003_000001|After supper the knight shared his own bed with the leper. 
+train-clean-360/2517/135227/2517_135227_000006_000005|As I was anxious to witness some of their purely religious ceremonies, I wished to go. +train-clean-360/2709/158074/2709_158074_000054_000000|Meanwhile the women continued to protest. +train-clean-360/2929/86777/2929_86777_000009_000000|A long silence followed; the peach, like the grapes, fell to the ground. +train-clean-360/318/124224/318_124224_000022_000010|In spite of his prejudice against Edward, he could put himself into Mr Waller's place, and see the thing from his point of view. +train-clean-360/3368/170952/3368_170952_000006_000000|And can he be fearless of death, or will he choose death in battle rather than defeat and slavery, who believes the world below to be real and terrible? +train-clean-360/3549/9203/3549_9203_000005_000004|We must hope so. There are examples. +train-clean-360/3835/178028/3835_178028_000007_000001|That day Prince Vasili no longer boasted of his protege Kutuzov, but remained silent when the commander in chief was mentioned. +train-clean-360/3994/149798/3994_149798_000005_000002|Afterward we can visit the mountain and punish the cruel magician of the Flatheads." +train-clean-360/4257/6397/4257_6397_000009_000000|At that time Nostromo had been already long enough in the country to raise to the highest pitch Captain Mitchell's opinion of the extraordinary value of his discovery. +train-clean-360/454/134728/454_134728_000133_000000|After a week of physical anguish, Unrest and pain, and feverish heat, Toward the ending day a calm and lull comes on, Three hours of peace and soothing rest of brain. +train-clean-360/4848/28247/4848_28247_000026_000002|Had he gained this arduous height only to behold the rocks carpeted with ice and snow, and reaching interminably to the far off horizon? +train-clean-360/5039/1189/5039_1189_000091_000000|The Shaggy Man sat down again and seemed well pleased. +train-clean-360/5261/19373/5261_19373_000011_000001|Some cause was evidently at work on this distant planet, causing it to disagree with its motion as calculated according to the law of gravitation. +train-clean-360/5538/70919/5538_70919_000032_000001|Only one person in the world could have laid those discoloured pearls at his door in the dead of night. The black figure in the garden, with the chiffon fluttering about its head, was Evelina Grey-or what was left of her. +train-clean-360/5712/48848/5712_48848_000060_000003|Lily for the time had been raised to a pinnacle,--a pinnacle which might be dangerous, but which was, at any rate, lofty. +train-clean-360/5935/43322/5935_43322_000050_000002|I think too-yes, I think that on the whole the ritual is impressive. +train-clean-360/6115/58433/6115_58433_000007_000002|We must run the risk." +train-clean-360/6341/64956/6341_64956_000040_000000|"Why, papa, I thought we were going to have such a nice time, and she just spoiled it all." +train-clean-360/6509/67147/6509_67147_000028_000003|It "was n't done" in England. +train-clean-360/6694/70837/6694_70837_000027_000002|There an enormous smiling sailor stopped me, and when I showed my pass, just said, "If you were Saint Michael himself, comrade, you couldn't pass here!" Through the glass of the door I made out the distorted face and gesticulating arms of a French correspondent, locked in.... +train-clean-360/6956/76046/6956_76046_000055_000001|Twelve hundred, fifteen hundred millions perhaps." +train-clean-360/7145/87280/7145_87280_000004_000003|This modern Ulysses made a masterful effort, but alas! 
had no ships to carry him away, and no wax with which to fill his ears. +train-clean-360/7314/77782/7314_77782_000011_000000|"Well, then, what in thunder is the matter with you?" cried the Lawyer, irritated. +train-clean-360/7525/92915/7525_92915_000034_000001|It was desperate, too, and lasted nearly all day-and it was one of the important battles of the world, although the numbers engaged in it were not large. +train-clean-360/7754/108640/7754_108640_000001_000004|Was I aware-was I fully aware of the discrepancy between us? +train-clean-360/7909/106369/7909_106369_000006_000002|And Colchian Aea lies at the edge of Pontus and of the world." +train-clean-360/8011/280922/8011_280922_000009_000000|He stretched out his hand, and all at once stroked my cheek. +train-clean-360/8176/115046/8176_115046_000027_000001|"Bless my soul, I never can understand it!" +train-clean-360/8459/292347/8459_292347_000015_000000|A woman near Gort, in Galway, says: 'There is a boy, now, of the Cloran's; but I wouldn't for the world let them think I spoke of him; it's two years since he came from America, and since that time he never went to Mass, or to church, or to fairs, or to market, or to stand on the cross roads, or to hurling, or to nothing. +train-clean-360/8699/291107/8699_291107_000003_000005|He leaned closer over it, regardless of the thin choking haze that spread about his face. In his attitude there was a rigidity of controlled excitement out of keeping with the seeming harmlessness of the experiment. +train-clean-360/8855/283242/8855_283242_000061_000000|"That couldn't be helped, grannie. +train-other-500/102/129232/102_129232_000050_000000|Is it otherwise in the newest romance? +train-other-500/1124/134775/1124_134775_000087_000001|Some of them are enclosed only by hedges, which lends a cheerful aspect to the street. +train-other-500/1239/138254/1239_138254_000010_000001|It was past twelve when all preparations were finished. +train-other-500/1373/132103/1373_132103_000056_000000|So they moved on. +train-other-500/1566/153036/1566_153036_000087_000003|You enter the river close by the trees, and then keep straight for the pile of stones, which is some fifty yards higher up, for the ford crosses the river at an angle." +train-other-500/1653/142352/1653_142352_000005_000002|If he should not come! +train-other-500/1710/133294/1710_133294_000023_000000|When the Indians were the sole inhabitants of the wilds from whence they have since been expelled, their wants were few. +train-other-500/1773/139602/1773_139602_000032_000001|When the rabbit saw that the badger was getting well, he thought of another plan by which he could compass the creature's death. +train-other-500/1920/1793/1920_1793_000037_000001|She has a little Blenheim lapdog, that she loves a thousand times more than she ever will me!" +train-other-500/2067/143535/2067_143535_000009_000002|Indeed, there, to the left, was a stone shelf with a little ledge to it three inches or so high, and on the shelf lay what I took to be a corpse; at any rate, it looked like one, with something white thrown over it. +train-other-500/2208/11020/2208_11020_000037_000001|It's at my place over there.' +train-other-500/2312/157868/2312_157868_000019_000002|I am the manager of the theatre, and I'm thundering glad that your first play has been produced at the 'New York,' sir. 
+train-other-500/2485/151992/2485_151992_000028_000005|At last he looked up at his wife and said, in a gentle tone: +train-other-500/2587/54186/2587_54186_000015_000000|Concerning the work as a whole he wrote to Clara while in the throes of composition: "This music now in me, and always such beautiful melodies! +train-other-500/2740/288813/2740_288813_000018_000003|But Philip had kept him apart, had banked him off, and yet drained him to the dregs. +train-other-500/2943/171001/2943_171001_000122_000000|The sound of his voice pronouncing her name aroused her. +train-other-500/3063/138651/3063_138651_000028_000000|But, as may be imagined, the unfortunate john was as much surprised by this rencounter as the other two. +train-other-500/3172/166439/3172_166439_000050_000000|And now at last was clear a thing that had puzzled greatly-the mechanism of that opening process by which sphere became oval disk, pyramid a four pointed star and-as I had glimpsed in the play of the Little Things about Norhala, could see now so plainly in the Keeper-the blocks took this inverted cruciform shape. +train-other-500/331/132019/331_132019_000038_000000|"I say, this is folly! +train-other-500/3467/166570/3467_166570_000054_000001|Does he never mention Orlando?" +train-other-500/3587/140711/3587_140711_000015_000001|O fie, mrs Jervis, said I, how could you serve me so? Besides, it looks too free both in me, and to him. +train-other-500/3675/187020/3675_187020_000026_000001|"I wonder what would be suitable? +train-other-500/3819/134146/3819_134146_000019_000001|Also the figure half hidden by the cupboard door-was a female figure, massive, and in flowing robes. +train-other-500/3912/77626/3912_77626_000003_000004|You may almost distinguish the figures on the clock that has just told the hour. +train-other-500/4015/63729/4015_63729_000058_000000|"It does." +train-other-500/413/22436/413_22436_000035_000003|I conjecture, the French squadron is bound for Malta and Alexandria, and the Spanish fleet for the attack of Minorca." +train-other-500/4218/41159/4218_41159_000028_000002|Yes? That worries Alexey. +train-other-500/4352/10940/4352_10940_000037_000002|He doesn't exist." +train-other-500/4463/26871/4463_26871_000023_000000|"I did not notice him following me," she said timidly. +train-other-500/4591/14356/4591_14356_000019_000000|"Within three days," cried the enchanter, loudly, "bring Rinaldo and Ricciardetto into the pass of Ronces Valles. +train-other-500/4738/291957/4738_291957_000000_000001|ODE ON THE SPRING. +train-other-500/4824/36029/4824_36029_000045_000003|And indeed Janet herself had taken no part in the politics, content merely to advise the combatants upon their demeanour. +train-other-500/4936/65528/4936_65528_000014_000007|I immediately responded, "Yes, they are most terrible struck on each other," and I said it in a tone that indicated I thought it a most beautiful and lovely thing that they should be so. +train-other-500/5019/38670/5019_38670_000017_000000|"Let me make you a present of the gloves," she said, with her irresistible smile. +train-other-500/5132/33409/5132_33409_000016_000001|They waited on the table in Valhalla. +train-other-500/52/121057/52_121057_000019_000000|"I," cried the steward with a strange expression. +train-other-500/5321/53046/5321_53046_000025_000003|I gather from what mrs joel said that she's rather touched in her mind too, and has an awful hankering to get home here-to this very house. +train-other-500/5429/210770/5429_210770_000029_000006|But this was not all. 
+train-other-500/557/129797/557_129797_000072_000001|The guns were manned, the gunners already kindling fuses, when the buccaneer fleet, whilst still heading for Palomas, was observed to bear away to the west. +train-other-500/572/128861/572_128861_000016_000002|My home was desolate. +train-other-500/5826/53497/5826_53497_000044_000001|If it be as you say, he will have shown himself noble, and his nobility will have consisted in this, that he has been willing to take that which he does not want, in order that he may succour one whom he loves. +train-other-500/5906/52158/5906_52158_000055_000000|The impression that he gets this knowledge or suspicion from the outside is due, the scientists say, to the fact that his thinking has proceeded at such lightning like speed that he was unable to watch the wheels go round. +train-other-500/6009/57639/6009_57639_000038_000000|This, friendly reader, is my only motive. +train-other-500/6106/58196/6106_58196_000007_000001|I tell you that you must make the dress. +train-other-500/6178/86034/6178_86034_000079_000004|Then she will grow calmer, and will know you again. +train-other-500/6284/63091/6284_63091_000133_000001|I don't want to go anywhere where anybody'll see me." +train-other-500/6436/104980/6436_104980_000009_000002|I guess you never heard about this house." +train-other-500/6540/232291/6540_232291_000017_000003|The girl was not wholly a savage. +train-other-500/6627/67844/6627_67844_000046_000002|The other girls had stopped talking, and now looked at Sylvia as if wondering what she would say. +train-other-500/6707/77351/6707_77351_000002_000006|But our first words I may give you, because though they conveyed nothing to me at the time, afterwards they meant much. +train-other-500/6777/76694/6777_76694_000013_000011|When they are forcibly put out of Garraway's on Saturday night-which they must be, for they never would go out of their own accord-where do they vanish until Monday morning? +train-other-500/690/133452/690_133452_000011_000000|Campany lifted his quill pen and pointed to a case of big leather bound volumes in a far corner of the room. +train-other-500/7008/34667/7008_34667_000032_000002|What had happened? +train-other-500/7131/92815/7131_92815_000039_000001|The cabman tried to pass to the left, but a heavy express wagon cut him off. +train-other-500/7220/77911/7220_77911_000005_000000|"Do? +train-other-500/7326/245693/7326_245693_000008_000000|Whether the Appetite Is a Special Power of the Soul? +train-other-500/7392/105672/7392_105672_000013_000005|Whoever, being required, refused to answer upon oath to any article of this act of settlement, was declared to be guilty of treason; and by this clause a species of political inquisition was established in the kingdom, as well as the accusations of treason multiplied to an unreasonable degree. +train-other-500/7512/98636/7512_98636_000017_000002|A man thus rarely makes provision for the future, and looks with scorn on foreign customs which seem to betoken a fear lest, in old age, ungrateful children may neglect their parents and cast them aside. +train-other-500/7654/258963/7654_258963_000007_000007|Egypt, for a time reduced to a semi desert condition, has only in the past century been restored to a certain extent by the use of new methods and a return to the old ones. +train-other-500/7769/99396/7769_99396_000020_000002|I had to go out once a day in search of food. +train-other-500/791/127519/791_127519_000086_000000|This was how it came about. 
+train-other-500/8042/113769/8042_113769_000021_000000|House the second. +train-other-500/8180/274725/8180_274725_000010_000000|"What fools men are in love matters," quoth Patty to herself-"at least most men!" with a thought backward to Mark's sensible choosing. +train-other-500/8291/282929/8291_282929_000031_000006|He's in a devil of a-Well, he needs the money, and I've got to get it for him. You know my word's good, Cooper." +train-other-500/8389/120181/8389_120181_000022_000000|"No," I answered. +train-other-500/8476/269293/8476_269293_000078_000001|Annie, in some wonder, went downstairs alone. +train-other-500/8675/295195/8675_295195_000004_000004|Everything had gone on prosperously with them, and they had reared many successive families of young Nutcrackers, who went forth to assume their places in the forest of life, and to reflect credit on their bringing up,--so that naturally enough they began to have a very easy way of considering themselves models of wisdom. +train-other-500/9000/282381/9000_282381_000016_000008|Bank facings seemed to indicate that the richest pay dirt lay at bed rock. +train-other-500/978/132494/978_132494_000017_000001|And what made you come at that very minute? diff --git a/FlashSR/BigVGAN/README.md b/FlashSR/BigVGAN/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a6cff37786a486deb55bc070254027aa492c2e92 --- /dev/null +++ b/FlashSR/BigVGAN/README.md @@ -0,0 +1,95 @@ +## BigVGAN: A Universal Neural Vocoder with Large-Scale Training +#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon + +
+ + +### [Paper](https://arxiv.org/abs/2206.04658) +### [Audio demo](https://bigvgan-demo.github.io/) + +## Installation +Clone the repository and install dependencies. +```shell +# the codebase has been tested on Python 3.8 / 3.10 with PyTorch 1.12.1 / 1.13 conda binaries +git clone https://github.com/NVIDIA/BigVGAN +pip install -r requirements.txt +``` + +Create symbolic links to the root of the dataset. The codebase uses filelists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset. +```shell +cd LibriTTS && \ +ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \ +ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \ +ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \ +ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \ +ln -s /path/to/your/LibriTTS/dev-other dev-other && \ +ln -s /path/to/your/LibriTTS/test-clean test-clean && \ +ln -s /path/to/your/LibriTTS/test-other test-other && \ +cd .. +``` + +## Training +Train the BigVGAN model. Below is an example command for training BigVGAN on the LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input. +```shell +python train.py \ +--config configs/bigvgan_24khz_100band.json \ +--input_wavs_dir LibriTTS \ +--input_training_file LibriTTS/train-full.txt \ +--input_validation_file LibriTTS/val-full.txt \ +--list_input_unseen_wavs_dir LibriTTS LibriTTS \ +--list_input_unseen_validation_file LibriTTS/dev-clean.txt LibriTTS/dev-other.txt \ +--checkpoint_path exp/bigvgan +``` + +## Synthesis +Synthesize from the BigVGAN model. Below is an example command for generating audio from the model. +It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`. +```shell +python inference.py \ +--checkpoint_file exp/bigvgan/g_05000000 \ +--input_wavs_dir /path/to/your/input_wav \ +--output_dir /path/to/your/output_wav +``` + +`inference_e2e.py` supports synthesis directly from a mel spectrogram saved in `.npy` format, with shape `[1, channel, frame]` or `[channel, frame]`. +It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`. + +Make sure that the STFT hyperparameters used to compute the mel spectrogram match those of the model, which are defined in the `config.json` of the corresponding model. +```shell +python inference_e2e.py \ +--checkpoint_file exp/bigvgan/g_05000000 \ +--input_mels_dir /path/to/your/input_mel \ +--output_dir /path/to/your/output_wav +``` + +## Pretrained Models +We provide the [pretrained models](https://drive.google.com/drive/folders/1e9wdM29d-t3EHUpBb8T4dcHrkYGAXTgq). +One can download the checkpoints of the generator (e.g., g_05000000) and discriminator (e.g., do_05000000) within the listed folders. + +|Folder Name|Sampling Rate|Mel band|fmax|Params.|Dataset|Fine-Tuned| +|------|---|---|---|---|------|---| +|bigvgan_24khz_100band|24 kHz|100|12000|112M|LibriTTS|No| +|bigvgan_base_24khz_100band|24 kHz|100|12000|14M|LibriTTS|No| +|bigvgan_22khz_80band|22 kHz|80|8000|112M|LibriTTS + VCTK + LJSpeech|No| +|bigvgan_base_22khz_80band|22 kHz|80|8000|14M|LibriTTS + VCTK + LJSpeech|No| + +The paper results are based on the 24kHz BigVGAN models trained on the LibriTTS dataset. +We also provide 22kHz BigVGAN models with a band-limited setup (i.e., fmax=8000) for TTS applications. +Note that the latest checkpoints use the ``snakebeta`` activation with log-scale parameterization, which gives the best overall quality.
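As a rough illustration of what "log-scale parameterization" means, here is a minimal, hypothetical sketch using the `SnakeBeta` module from `activations.py` in this repository. It assumes the script is run from the `FlashSR/BigVGAN` directory (or that this directory is on `PYTHONPATH`); the tensor sizes are arbitrary.

```python
import torch
from activations import SnakeBeta  # FlashSR/BigVGAN/activations.py; adjust the import path if needed

# SnakeBeta(x) = x + (1/beta) * sin^2(alpha * x).
# With alpha_logscale=True the alpha/beta parameters are stored in log space and
# exponentiated in forward(), so the zero-initialized parameters correspond to alpha = beta = 1.
act = SnakeBeta(in_features=64, alpha_logscale=True)

x = torch.randn(2, 64, 100)  # [B, C, T]; C must equal in_features
y = act(x)
print(y.shape)               # torch.Size([2, 64, 100]) - shape is preserved
```

In the model, this activation is meant to sit inside the anti-aliased `Activation1d` wrapper from `alias_free_torch` (upsample, activation, downsample), which is the "filtered nonlinearity" referred to in the TODO below.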
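Relatedly, for the `inference_e2e.py` workflow above, the following hypothetical sketch (not part of the repository) shows how the expected `[channel, frame]` `.npy` input and the config-driven STFT parameters fit together. It assumes `librosa` is installed and uses placeholder paths; librosa's mel conventions do not exactly match the mel extraction used during training, so treat it as a shape/parameter illustration rather than a drop-in preprocessing step.

```python
# Hypothetical helper: compute a mel spectrogram whose STFT parameters come from the
# model config and save it as .npy for inference_e2e.py.
import json
import numpy as np
import librosa

with open("configs/bigvgan_24khz_100band.json") as f:
    h = json.load(f)

wav, _ = librosa.load("/path/to/your/input_wav/example.wav", sr=h["sampling_rate"])

# NOTE: librosa's defaults (Slaney mel filters, its own normalization) differ from the
# log-magnitude mel used for training, so this is purely a shape/parameter demo.
mel = librosa.feature.melspectrogram(
    y=wav,
    sr=h["sampling_rate"],
    n_fft=h["n_fft"],
    hop_length=h["hop_size"],
    win_length=h["win_size"],
    n_mels=h["num_mels"],
    fmin=h["fmin"],
    fmax=h["fmax"],
    power=1.0,
)
mel = np.log(np.clip(mel, a_min=1e-5, a_max=None))  # simple log compression

# [channel, frame] layout expected by inference_e2e.py
np.save("/path/to/your/input_mel/example.npy", mel.astype(np.float32))
```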
+ + +## TODO + +Current codebase only provides a plain PyTorch implementation for the filtered nonlinearity. We are working on a fast CUDA kernel implementation, which will be released in the future. + + +## References +* [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator) + +* [Snake](https://github.com/EdwardDixon/snake) (for periodic activation) + +* [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing) + +* [Julius](https://github.com/adefossez/julius) (for low-pass filter) + +* [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator) \ No newline at end of file diff --git a/FlashSR/BigVGAN/activations.py b/FlashSR/BigVGAN/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..61f2808a5466b3cf4d041059700993af5527dd29 --- /dev/null +++ b/FlashSR/BigVGAN/activations.py @@ -0,0 +1,120 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. 
+ INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x \ No newline at end of file diff --git a/FlashSR/BigVGAN/alias_free_torch/__init__.py b/FlashSR/BigVGAN/alias_free_torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a2318b63198250856809c0cb46210a4147b829bc --- /dev/null +++ b/FlashSR/BigVGAN/alias_free_torch/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .filter import * +from .resample import * +from .act import * \ No newline at end of file diff --git a/FlashSR/BigVGAN/alias_free_torch/act.py b/FlashSR/BigVGAN/alias_free_torch/act.py new file mode 100644 index 0000000000000000000000000000000000000000..028debd697dd60458aae75010057df038bd3518a --- /dev/null +++ b/FlashSR/BigVGAN/alias_free_torch/act.py @@ -0,0 +1,28 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/FlashSR/BigVGAN/alias_free_torch/filter.py b/FlashSR/BigVGAN/alias_free_torch/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad6ea87c1f10ddd94c544037791d7a4634d5ae1 --- /dev/null +++ b/FlashSR/BigVGAN/alias_free_torch/filter.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + + return out \ No newline at end of file diff --git a/FlashSR/BigVGAN/alias_free_torch/resample.py b/FlashSR/BigVGAN/alias_free_torch/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..750e6c3402cc5ac939c4b9d075246562e0e1d1a7 --- /dev/null +++ b/FlashSR/BigVGAN/alias_free_torch/resample.py @@ -0,0 +1,49 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
+ +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx \ No newline at end of file diff --git a/FlashSR/BigVGAN/configs/bigvgan_22khz_80band.json b/FlashSR/BigVGAN/configs/bigvgan_22khz_80band.json new file mode 100644 index 0000000000000000000000000000000000000000..9ebdf92a04cbee73b949dc2a5d367553c23d6115 --- /dev/null +++ b/FlashSR/BigVGAN/configs/bigvgan_22khz_80band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/FlashSR/BigVGAN/configs/bigvgan_24khz_100band.json b/FlashSR/BigVGAN/configs/bigvgan_24khz_100band.json new file mode 100644 index 0000000000000000000000000000000000000000..d9988a91c46eb6423fbb3c89d4b00daa49021d74 --- /dev/null +++ b/FlashSR/BigVGAN/configs/bigvgan_24khz_100band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + 
"use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 100, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 24000, + + "fmin": 0, + "fmax": 12000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/FlashSR/BigVGAN/configs/bigvgan_base_22khz_80band.json b/FlashSR/BigVGAN/configs/bigvgan_base_22khz_80band.json new file mode 100644 index 0000000000000000000000000000000000000000..32979f5228e6a82201d43c0d8f82262a78e99146 --- /dev/null +++ b/FlashSR/BigVGAN/configs/bigvgan_base_22khz_80band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + + "upsample_rates": [8,8,2,2], + "upsample_kernel_sizes": [16,16,4,4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/FlashSR/BigVGAN/configs/bigvgan_base_24khz_100band.json b/FlashSR/BigVGAN/configs/bigvgan_base_24khz_100band.json new file mode 100644 index 0000000000000000000000000000000000000000..889a77c2623579dcf0b71f1b74c5861ae9774515 --- /dev/null +++ b/FlashSR/BigVGAN/configs/bigvgan_base_24khz_100band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + + "upsample_rates": [8,8,2,2], + "upsample_kernel_sizes": [16,16,4,4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 100, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 24000, + + "fmin": 0, + "fmax": 12000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/FlashSR/BigVGAN/env.py b/FlashSR/BigVGAN/env.py new file mode 100644 index 0000000000000000000000000000000000000000..b8be238d4db710c8c9a338d336baea0138f18d1f --- /dev/null +++ b/FlashSR/BigVGAN/env.py @@ -0,0 +1,18 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
+ +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) \ No newline at end of file diff --git a/FlashSR/BigVGAN/incl_licenses/LICENSE_1 b/FlashSR/BigVGAN/incl_licenses/LICENSE_1 new file mode 100644 index 0000000000000000000000000000000000000000..5afae394d6b37da0e12ba6b290d2512687f421ac --- /dev/null +++ b/FlashSR/BigVGAN/incl_licenses/LICENSE_1 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/FlashSR/BigVGAN/incl_licenses/LICENSE_2 b/FlashSR/BigVGAN/incl_licenses/LICENSE_2 new file mode 100644 index 0000000000000000000000000000000000000000..322b758863c4219be68291ae3826218baa93cb4c --- /dev/null +++ b/FlashSR/BigVGAN/incl_licenses/LICENSE_2 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Edward Dixon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/FlashSR/BigVGAN/incl_licenses/LICENSE_3 b/FlashSR/BigVGAN/incl_licenses/LICENSE_3 new file mode 100644 index 0000000000000000000000000000000000000000..56ee3c8c4cc2b4b32e0975d17258f9ba515fdbcc --- /dev/null +++ b/FlashSR/BigVGAN/incl_licenses/LICENSE_3 @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/FlashSR/BigVGAN/incl_licenses/LICENSE_4 b/FlashSR/BigVGAN/incl_licenses/LICENSE_4 new file mode 100644 index 0000000000000000000000000000000000000000..48fd1a1ba8d81a94b6c7d1c2ff1a1f307cc5371d --- /dev/null +++ b/FlashSR/BigVGAN/incl_licenses/LICENSE_4 @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2019, Seungwon Park 박승원 +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/FlashSR/BigVGAN/incl_licenses/LICENSE_5 b/FlashSR/BigVGAN/incl_licenses/LICENSE_5 new file mode 100644 index 0000000000000000000000000000000000000000..01ae5538e6b7c787bb4f5d6f2cd9903520d6e465 --- /dev/null +++ b/FlashSR/BigVGAN/incl_licenses/LICENSE_5 @@ -0,0 +1,16 @@ +Copyright 2020 Alexandre Défossez + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, +including without limitation the rights to use, copy, modify, merge, publish, distribute, +sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or +substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/FlashSR/BigVGAN/inference.py b/FlashSR/BigVGAN/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..0769fbbe5f3a6bb0c7b860fa3b7f003b82de8da2 --- /dev/null +++ b/FlashSR/BigVGAN/inference.py @@ -0,0 +1,104 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import glob +import os +import argparse +import json +import torch +from scipy.io.wavfile import write +from env import AttrDict +from meldataset import mel_spectrogram, MAX_WAV_VALUE +from models import BigVGAN as Generator +import librosa + +h = None +device = None +torch.backends.cudnn.benchmark = False + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def get_mel(x): + return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax) + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '*') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return '' + return sorted(cp_list)[-1] + + +def inference(a, h): + generator = Generator(h).to(device) + + state_dict_g = load_checkpoint(a.checkpoint_file, device) + generator.load_state_dict(state_dict_g['generator']) + + filelist = os.listdir(a.input_wavs_dir) + + os.makedirs(a.output_dir, exist_ok=True) + + generator.eval() + generator.remove_weight_norm() + with torch.no_grad(): + for i, filname in enumerate(filelist): + # load the ground truth audio and resample if necessary + wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), h.sampling_rate, mono=True) + wav = torch.FloatTensor(wav).to(device) + # compute mel spectrogram from the ground truth audio + x = get_mel(wav.unsqueeze(0)) + + y_g_hat = generator(x) + + audio = y_g_hat.squeeze() + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype('int16') + + output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated.wav') + write(output_file, h.sampling_rate, audio) + print(output_file) + + +def main(): + print('Initializing Inference Process..') + + parser = argparse.ArgumentParser() + parser.add_argument('--input_wavs_dir', default='test_files') + parser.add_argument('--output_dir', default='generated_files') + parser.add_argument('--checkpoint_file', required=True) + + a = parser.parse_args() + + config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') + with open(config_file) as f: + data = f.read() + + global h + json_config = json.loads(data) + h = AttrDict(json_config) + + torch.manual_seed(h.seed) + global device + if torch.cuda.is_available(): + torch.cuda.manual_seed(h.seed) + device = torch.device('cuda') + else: + device = torch.device('cpu') + + inference(a, h) + + +if __name__ == '__main__': + main() + diff --git 
a/FlashSR/BigVGAN/inference_e2e.py b/FlashSR/BigVGAN/inference_e2e.py new file mode 100644 index 0000000000000000000000000000000000000000..9d2ad6080c0498514d64a9243778edf525f77854 --- /dev/null +++ b/FlashSR/BigVGAN/inference_e2e.py @@ -0,0 +1,100 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import glob +import os +import numpy as np +import argparse +import json +import torch +from scipy.io.wavfile import write +from env import AttrDict +from meldataset import MAX_WAV_VALUE +from models import BigVGAN as Generator + +h = None +device = None +torch.backends.cudnn.benchmark = False + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '*') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return '' + return sorted(cp_list)[-1] + + +def inference(a, h): + generator = Generator(h).to(device) + + state_dict_g = load_checkpoint(a.checkpoint_file, device) + generator.load_state_dict(state_dict_g['generator']) + + filelist = os.listdir(a.input_mels_dir) + + os.makedirs(a.output_dir, exist_ok=True) + + generator.eval() + generator.remove_weight_norm() + with torch.no_grad(): + for i, filname in enumerate(filelist): + # load the mel spectrogram in .npy format + x = np.load(os.path.join(a.input_mels_dir, filname)) + x = torch.FloatTensor(x).to(device) + if len(x.shape) == 2: + x = x.unsqueeze(0) + + y_g_hat = generator(x) + + audio = y_g_hat.squeeze() + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype('int16') + + output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated_e2e.wav') + write(output_file, h.sampling_rate, audio) + print(output_file) + + +def main(): + print('Initializing Inference Process..') + + parser = argparse.ArgumentParser() + parser.add_argument('--input_mels_dir', default='test_mel_files') + parser.add_argument('--output_dir', default='generated_files_from_mel') + parser.add_argument('--checkpoint_file', required=True) + + a = parser.parse_args() + + config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') + with open(config_file) as f: + data = f.read() + + global h + json_config = json.loads(data) + h = AttrDict(json_config) + + torch.manual_seed(h.seed) + global device + if torch.cuda.is_available(): + torch.cuda.manual_seed(h.seed) + device = torch.device('cuda') + else: + device = torch.device('cpu') + + inference(a, h) + + +if __name__ == '__main__': + main() + diff --git a/FlashSR/BigVGAN/meldataset.py b/FlashSR/BigVGAN/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..306f301802ddf1c87fd5f9c55dc125e41c45cbcd --- /dev/null +++ b/FlashSR/BigVGAN/meldataset.py @@ -0,0 +1,212 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
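
For orientation, both inference scripts above follow the same core path: read config.json from the checkpoint directory into an AttrDict, build the BigVGAN generator, load the 'generator' state dict, call remove_weight_norm(), then synthesize and write int16 WAV files. A condensed sketch of the inference_e2e.py flow follows; the config and checkpoint filenames are placeholders, not part of this diff.

import json
import numpy as np
import torch
from scipy.io.wavfile import write
from env import AttrDict
from meldataset import MAX_WAV_VALUE
from models import BigVGAN as Generator

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# config.json is expected to sit next to the checkpoint file (see main() above)
with open('exp/bigvgan/config.json') as f:
    h = AttrDict(json.load(f))

generator = Generator(h).to(device)
state_dict_g = torch.load('exp/bigvgan/g_00100000', map_location=device)  # placeholder checkpoint name
generator.load_state_dict(state_dict_g['generator'])
generator.eval()
generator.remove_weight_norm()

mel = torch.FloatTensor(np.load('test_mel_files/sample.npy')).to(device)  # [num_mels, frames]
if mel.dim() == 2:
    mel = mel.unsqueeze(0)                                                # [1, num_mels, frames]
with torch.no_grad():
    audio = generator(mel).squeeze()
write('sample_generated_e2e.wav', h.sampling_rate, (audio * MAX_WAV_VALUE).cpu().numpy().astype('int16'))
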
+ +import math +import os +import random +import torch +import torch.utils.data +import numpy as np +from librosa.util import normalize +from scipy.io.wavfile import read +from librosa.filters import mel as librosa_mel_fn +import pathlib +from tqdm import tqdm + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path, sr_target): + sampling_rate, data = read(full_path) + if sampling_rate != sr_target: + raise RuntimeError("Sampling rate of the file {} is {} Hz, but the model requires {} Hz". + format(full_path, sampling_rate, sr_target)) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + # complex tensor as default, then use view_as_real for future pytorch compatibility + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + + spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def get_dataset_filelist(a): + with open(a.input_training_file, 'r', encoding='utf-8') as fi: + training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') + for x in fi.read().split('\n') if len(x) > 0] + print("first training file: {}".format(training_files[0])) + + with open(a.input_validation_file, 'r', encoding='utf-8') as fi: + validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') + for x in fi.read().split('\n') if len(x) > 0] + print("first validation file: {}".format(validation_files[0])) + + list_unseen_validation_files = [] + for i in range(len(a.list_input_unseen_validation_file)): + with open(a.list_input_unseen_validation_file[i], 'r', encoding='utf-8') as fi: + unseen_validation_files = [os.path.join(a.list_input_unseen_wavs_dir[i], x.split('|')[0] + '.wav') + for x in fi.read().split('\n') if len(x) > 0] + print("first unseen {}th validation fileset: {}".format(i, unseen_validation_files[0])) + list_unseen_validation_files.append(unseen_validation_files) + + return training_files, validation_files, list_unseen_validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__(self, 
training_files, hparams, segment_size, n_fft, num_mels, + hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1, + device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None, is_seen=True): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.hparams = hparams + self.is_seen = is_seen + if self.is_seen: + self.name = pathlib.Path(self.audio_files[0]).parts[0] + else: + self.name = '-'.join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/") + + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + print("INFO: checking dataset integrity...") + for i in tqdm(range(len(self.audio_files))): + assert os.path.exists(self.audio_files[i]), "{} not found".format(self.audio_files[i]) + + def __getitem__(self, index): + + filename = self.audio_files[index] + if self._cache_ref_count == 0: + audio, sampling_rate = load_wav(filename, self.sampling_rate) + audio = audio / MAX_WAV_VALUE + if not self.fine_tuning: + audio = normalize(audio) * 0.95 + self.cached_wav = audio + if sampling_rate != self.sampling_rate: + raise ValueError("{} SR doesn't match target {} SR".format( + sampling_rate, self.sampling_rate)) + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start:audio_start+self.segment_size] + else: + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') + + mel = mel_spectrogram(audio, self.n_fft, self.num_mels, + self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax, + center=False) + else: # validation step + # match audio length to self.hop_size * n for evaluation + if (audio.size(1) % self.hop_size) != 0: + audio = audio[:, :-(audio.size(1) % self.hop_size)] + mel = mel_spectrogram(audio, self.n_fft, self.num_mels, + self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax, + center=False) + assert audio.shape[1] == mel.shape[2] * self.hop_size, "audio shape {} mel shape {}".format(audio.shape, mel.shape) + + else: + mel = np.load( + os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy')) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start:mel_start + frames_per_seg] + audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size] + else: + mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant') + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') + + mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels, + 
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss, + center=False) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) diff --git a/FlashSR/BigVGAN/models.py b/FlashSR/BigVGAN/models.py new file mode 100644 index 0000000000000000000000000000000000000000..5026ded81ee5265efd2cca05314acc4988c2ac64 --- /dev/null +++ b/FlashSR/BigVGAN/models.py @@ -0,0 +1,388 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. +from TorchJaekwon.Util.Util import Util +from TorchJaekwon.Util.UtilData import UtilData +from easydict import EasyDict +Util.set_sys_path_to_parent_dir(__file__, depth_to_dir_from_file=2) + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +import Model.BigVGAN.activations as activations +from Model.BigVGAN.utils import init_weights, get_padding +from Model.BigVGAN.alias_free_torch import * + +LRELU_SLOPE = 0.1 + + +class AMPBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None): + super(AMPBlock1, self).__init__() + self.h = h + + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError("activation incorrectly specified. 
check the config file and look for 'activation'.") + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class AMPBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None): + super(AMPBlock2, self).__init__() + self.h = h + + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") + + def forward(self, x): + for c, a in zip (self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class BigVGAN(torch.nn.Module): + # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. + def __init__(self, h): + super(BigVGAN, self).__init__() + self.h = h + + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + + # pre conv + self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) + + # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2 + + # transposed conv-based upsamplers. 
does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append(nn.ModuleList([ + weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, u, padding=(k - u) // 2)) + ])) + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d, activation=h.activation)) + + # post conv + if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing + activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing + activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + else: + raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + + # weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + # pre conv + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + # upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # post conv + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.d_mult = h.discriminator_channel_mult + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in 
self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, h): + super(MultiPeriodDiscriminator, self).__init__() + self.mpd_reshapes = h.mpd_reshapes + print("mpd_reshapes: {}".format(self.mpd_reshapes)) + discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes] + self.discriminators = nn.ModuleList(discriminators) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__(self, cfg, resolution): + super().__init__() + + self.resolution = resolution + assert len(self.resolution) == 3, \ + "MRD layer requires list with len=3, got {}".format(self.resolution) + self.lrelu_slope = LRELU_SLOPE + + norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm + if hasattr(cfg, "mrd_use_spectral_norm"): + print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm)) + norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm + self.d_mult = cfg.discriminator_channel_mult + if hasattr(cfg, "mrd_channel_mult"): + print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult)) + self.d_mult = cfg.mrd_channel_mult + + self.convs = nn.ModuleList([ + norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))), + ]) + self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))) + + def forward(self, x): + fmap = [] + + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + def spectrogram(self, x): + n_fft, hop_length, win_length = self.resolution + x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect') + x = x.squeeze(1) + x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True) + x = torch.view_as_real(x) # [B, F, TT, 2] + mag = torch.norm(x, p=2, dim =-1) #[B, F, TT] + + return mag + + +class MultiResolutionDiscriminator(nn.Module): + def __init__(self, cfg, debug=False): + super().__init__() + self.resolutions = cfg.resolutions + assert len(self.resolutions) == 3,\ + "MRD requires list of list with len=3, each element having a list with len=3. 
got {}".\ + format(self.resolutions) + self.discriminators = nn.ModuleList( + [DiscriminatorR(cfg, resolution) for resolution in self.resolutions] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss*2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1-dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1-dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + +if __name__ == '__main__': + big_v_gan = BigVGAN(EasyDict(UtilData.json_load('Model/BigVGAN/configs/bigvgan_24khz_100band.json'))) + big_v_gan(torch.randn(4, 100, 282)) + print('') \ No newline at end of file diff --git a/FlashSR/BigVGAN/parse_scripts/parse_libritts.py b/FlashSR/BigVGAN/parse_scripts/parse_libritts.py new file mode 100644 index 0000000000000000000000000000000000000000..0886ed027c5c3e52e3711abca5af122c87f0ca08 --- /dev/null +++ b/FlashSR/BigVGAN/parse_scripts/parse_libritts.py @@ -0,0 +1,62 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import os, glob + +def get_wav_and_text_filelist(data_root, data_type, subsample=1): + wav_list = sorted([path.replace(data_root, "")[1:] for path in glob.glob(os.path.join(data_root, data_type, "**/**/*.wav"))]) + wav_list = wav_list[::subsample] + txt_filelist = [path.replace('.wav', '.normalized.txt') for path in wav_list] + + txt_list = [] + for txt_file in txt_filelist: + with open(os.path.join(data_root, txt_file), 'r') as f_txt: + text = f_txt.readline().strip('\n') + txt_list.append(text) + wav_list = [path.replace('.wav', '') for path in wav_list] + + return wav_list, txt_list + +def write_filelist(output_path, wav_list, txt_list): + with open(output_path, 'w') as f: + for i in range(len(wav_list)): + filename = wav_list[i] + '|' + txt_list[i] + f.write(filename + '\n') + +if __name__ == "__main__": + + data_root = "LibriTTS" + + # dev and test sets. subsample each sets to get ~100 utterances + data_type_list = ["dev-clean", "dev-other", "test-clean", "test-other"] + subsample_list = [50, 50, 50, 50] + for (data_type, subsample) in zip(data_type_list, subsample_list): + print("processing {}".format(data_type)) + data_path = os.path.join(data_root, data_type) + assert os.path.exists(data_path),\ + "path {} not found. 
make sure the path is accessible by creating the symbolic link using the following command: "\ + "ln -s /path/to/your/{} {}".format(data_path, data_path, data_path) + wav_list, txt_list = get_wav_and_text_filelist(data_root, data_type, subsample) + write_filelist(os.path.join(data_root, data_type+".txt"), wav_list, txt_list) + + # training and seen speaker validation datasets (libritts-full): train-clean-100 + train-clean-360 + train-other-500 + wav_list_train, txt_list_train = [], [] + for data_type in ["train-clean-100", "train-clean-360", "train-other-500"]: + print("processing {}".format(data_type)) + data_path = os.path.join(data_root, data_type) + assert os.path.exists(data_path),\ + "path {} not found. make sure the path is accessible by creating the symbolic link using the following command: "\ + "ln -s /path/to/your/{} {}".format(data_path, data_path, data_path) + wav_list, txt_list = get_wav_and_text_filelist(data_root, data_type) + wav_list_train.extend(wav_list) + txt_list_train.extend(txt_list) + + # split the training set so that the seen speaker validation set contains ~100 utterances + subsample_val = 3000 + wav_list_val, txt_list_val = wav_list_train[::subsample_val], txt_list_train[::subsample_val] + del wav_list_train[::subsample_val] + del txt_list_train[::subsample_val] + write_filelist(os.path.join(data_root, "train-full.txt"), wav_list_train, txt_list_train) + write_filelist(os.path.join(data_root, "val-full.txt"), wav_list_val, txt_list_val) + + print("done") \ No newline at end of file diff --git a/FlashSR/BigVGAN/requirements.txt b/FlashSR/BigVGAN/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2bad038ca517fef1bae14d0d5c7fe078696c23ce --- /dev/null +++ b/FlashSR/BigVGAN/requirements.txt @@ -0,0 +1,10 @@ +torch +numpy +librosa==0.8.1 +scipy +tensorboard +soundfile +matplotlib +pesq +auraloss +tqdm \ No newline at end of file diff --git a/FlashSR/BigVGAN/train.py b/FlashSR/BigVGAN/train.py new file mode 100644 index 0000000000000000000000000000000000000000..cadee69a5d9757c487d1ef4ca87df85c0b3e5470 --- /dev/null +++ b/FlashSR/BigVGAN/train.py @@ -0,0 +1,445 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
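
As a quick sanity check on tensor shapes, the __main__ block of models.py drives the generator with a random mel batch. A minimal version of that check, assuming the bigvgan_24khz_100band.json config referenced there is available at the same path:

import torch
from easydict import EasyDict
from TorchJaekwon.Util.UtilData import UtilData
from models import BigVGAN

# config path copied from the __main__ block above; adjust to your checkout layout
h = EasyDict(UtilData.json_load('Model/BigVGAN/configs/bigvgan_24khz_100band.json'))
generator = BigVGAN(h)

mel = torch.randn(4, h.num_mels, 282)   # [batch, num_mels, frames]
wav = generator(mel)                    # [batch, 1, frames * prod(h.upsample_rates)]
print(wav.shape)
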
+ + +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) +import itertools +import os +import time +import argparse +import json +import torch +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DistributedSampler, DataLoader +import torch.multiprocessing as mp +from torch.distributed import init_process_group +from torch.nn.parallel import DistributedDataParallel +from env import AttrDict, build_env +from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist, MAX_WAV_VALUE +from models import BigVGAN, MultiPeriodDiscriminator, MultiResolutionDiscriminator,\ + feature_loss, generator_loss, discriminator_loss +from utils import plot_spectrogram, plot_spectrogram_clipped, scan_checkpoint, load_checkpoint, save_checkpoint, save_audio +import torchaudio as ta +from pesq import pesq +from tqdm import tqdm +import auraloss + +torch.backends.cudnn.benchmark = False + +def train(rank, a, h): + if h.num_gpus > 1: + # initialize distributed + init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'], + world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank) + + # set seed and device + torch.cuda.manual_seed(h.seed) + torch.cuda.set_device(rank) + device = torch.device('cuda:{:d}'.format(rank)) + + # define BigVGAN generator + generator = BigVGAN(h).to(device) + print("Generator params: {}".format(sum(p.numel() for p in generator.parameters()))) + + # define discriminators. MPD is used by default + mpd = MultiPeriodDiscriminator(h).to(device) + print("Discriminator mpd params: {}".format(sum(p.numel() for p in mpd.parameters()))) + + # define additional discriminators. BigVGAN uses MRD as default + mrd = MultiResolutionDiscriminator(h).to(device) + print("Discriminator mrd params: {}".format(sum(p.numel() for p in mrd.parameters()))) + + # create or scan the latest checkpoint from checkpoints directory + if rank == 0: + print(generator) + os.makedirs(a.checkpoint_path, exist_ok=True) + print("checkpoints directory : ", a.checkpoint_path) + + if os.path.isdir(a.checkpoint_path): + cp_g = scan_checkpoint(a.checkpoint_path, 'g_') + cp_do = scan_checkpoint(a.checkpoint_path, 'do_') + + # load the latest checkpoint if exists + steps = 0 + if cp_g is None or cp_do is None: + state_dict_do = None + last_epoch = -1 + else: + state_dict_g = load_checkpoint(cp_g, device) + state_dict_do = load_checkpoint(cp_do, device) + generator.load_state_dict(state_dict_g['generator']) + mpd.load_state_dict(state_dict_do['mpd']) + mrd.load_state_dict(state_dict_do['mrd']) + steps = state_dict_do['steps'] + 1 + last_epoch = state_dict_do['epoch'] + + # initialize DDP, optimizers, and schedulers + if h.num_gpus > 1: + generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) + mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) + mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device) + + optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) + optim_d = torch.optim.AdamW(itertools.chain(mrd.parameters(), mpd.parameters()), + h.learning_rate, betas=[h.adam_b1, h.adam_b2]) + + if state_dict_do is not None: + optim_g.load_state_dict(state_dict_do['optim_g']) + optim_d.load_state_dict(state_dict_do['optim_d']) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, 
last_epoch=last_epoch) + + # define training and validation datasets + # unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset + # example: trained on LibriTTS, validate on VCTK + training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a) + + trainset = MelDataset(training_filelist, h, h.segment_size, h.n_fft, h.num_mels, + h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0, + shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device, + fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir, is_seen=True) + + train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None + + train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False, + sampler=train_sampler, + batch_size=h.batch_size, + pin_memory=True, + drop_last=True) + + if rank == 0: + validset = MelDataset(validation_filelist, h, h.segment_size, h.n_fft, h.num_mels, + h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0, + fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning, + base_mels_path=a.input_mels_dir, is_seen=True) + validation_loader = DataLoader(validset, num_workers=1, shuffle=False, + sampler=None, + batch_size=1, + pin_memory=True, + drop_last=True) + + list_unseen_validset = [] + list_unseen_validation_loader = [] + for i in range(len(list_unseen_validation_filelist)): + unseen_validset = MelDataset(list_unseen_validation_filelist[i], h, h.segment_size, h.n_fft, h.num_mels, + h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0, + fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning, + base_mels_path=a.input_mels_dir, is_seen=False) + unseen_validation_loader = DataLoader(unseen_validset, num_workers=1, shuffle=False, + sampler=None, + batch_size=1, + pin_memory=True, + drop_last=True) + list_unseen_validset.append(unseen_validset) + list_unseen_validation_loader.append(unseen_validation_loader) + + # Tensorboard logger + sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs')) + if a.save_audio: # also save audio to disk if --save_audio is set to True + os.makedirs(os.path.join(a.checkpoint_path, 'samples'), exist_ok=True) + + # validation loop + # "mode" parameter is automatically defined as (seen or unseen)_(name of the dataset) + # if the name of the dataset contains "nonspeech", it skips PESQ calculation to prevent errors + def validate(rank, a, h, loader, mode="seen"): + assert rank == 0, "validate should only run on rank=0" + generator.eval() + torch.cuda.empty_cache() + + val_err_tot = 0 + val_pesq_tot = 0 + val_mrstft_tot = 0 + + # modules for evaluation metrics + pesq_resampler = ta.transforms.Resample(h.sampling_rate, 16000).cuda() + loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") + + if a.save_audio: # also save audio to disk if --save_audio is set to True + os.makedirs(os.path.join(a.checkpoint_path, 'samples', 'gt_{}'.format(mode)), exist_ok=True) + os.makedirs(os.path.join(a.checkpoint_path, 'samples', '{}_{:08d}'.format(mode, steps)), exist_ok=True) + + with torch.no_grad(): + print("step {} {} speaker validation...".format(steps, mode)) + + # loop over validation set and compute metrics + for j, batch in tqdm(enumerate(loader)): + x, y, _, y_mel = batch + y = y.to(device) + if hasattr(generator, 'module'): + y_g_hat = generator.module(x.to(device)) + else: + y_g_hat = generator(x.to(device)) + y_mel = y_mel.to(device, 
non_blocking=True) + y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, + h.hop_size, h.win_size, + h.fmin, h.fmax_for_loss) + val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item() + + # PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out) + if not "nonspeech" in mode: # skips if the name of dataset (in mode string) contains "nonspeech" + # resample to 16000 for pesq + y_16k = pesq_resampler(y) + y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1)) + y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() + y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() + val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') + + # MRSTFT calculation + val_mrstft_tot += loss_mrstft(y_g_hat.squeeze(1), y).item() + + # log audio and figures to Tensorboard + if j % a.eval_subsample == 0: # subsample every nth from validation set + if steps >= 0: + sw.add_audio('gt_{}/y_{}'.format(mode, j), y[0], steps, h.sampling_rate) + if a.save_audio: # also save audio to disk if --save_audio is set to True + save_audio(y[0], os.path.join(a.checkpoint_path, 'samples', 'gt_{}'.format(mode), '{:04d}.wav'.format(j)), h.sampling_rate) + sw.add_figure('gt_{}/y_spec_{}'.format(mode, j), plot_spectrogram(x[0]), steps) + + sw.add_audio('generated_{}/y_hat_{}'.format(mode, j), y_g_hat[0], steps, h.sampling_rate) + if a.save_audio: # also save audio to disk if --save_audio is set to True + save_audio(y_g_hat[0, 0], os.path.join(a.checkpoint_path, 'samples', '{}_{:08d}'.format(mode, steps), '{:04d}.wav'.format(j)), h.sampling_rate) + # spectrogram of synthesized audio + y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, + h.sampling_rate, h.hop_size, h.win_size, + h.fmin, h.fmax) + sw.add_figure('generated_{}/y_hat_spec_{}'.format(mode, j), + plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps) + # visualization of spectrogram difference between GT and synthesized audio + # difference higher than 1 is clipped for better visualization + spec_delta = torch.clamp(torch.abs(x[0] - y_hat_spec.squeeze(0).cpu()), min=1e-6, max=1.) 
+ sw.add_figure('delta_dclip1_{}/spec_{}'.format(mode, j), + plot_spectrogram_clipped(spec_delta.numpy(), clip_max=1.), steps) + + val_err = val_err_tot / (j + 1) + val_pesq = val_pesq_tot / (j + 1) + val_mrstft = val_mrstft_tot / (j + 1) + # log evaluation metrics to Tensorboard + sw.add_scalar("validation_{}/mel_spec_error".format(mode), val_err, steps) + sw.add_scalar("validation_{}/pesq".format(mode), val_pesq, steps) + sw.add_scalar("validation_{}/mrstft".format(mode), val_mrstft, steps) + + generator.train() + + # if the checkpoint is loaded, start with validation loop + if steps != 0 and rank == 0 and not a.debug: + if not a.skip_seen: + validate(rank, a, h, validation_loader, + mode="seen_{}".format(train_loader.dataset.name)) + for i in range(len(list_unseen_validation_loader)): + validate(rank, a, h, list_unseen_validation_loader[i], + mode="unseen_{}".format(list_unseen_validation_loader[i].dataset.name)) + # exit the script if --evaluate is set to True + if a.evaluate: + exit() + + # main training loop + generator.train() + mpd.train() + mrd.train() + for epoch in range(max(0, last_epoch), a.training_epochs): + if rank == 0: + start = time.time() + print("Epoch: {}".format(epoch+1)) + + if h.num_gpus > 1: + train_sampler.set_epoch(epoch) + + for i, batch in enumerate(train_loader): + if rank == 0: + start_b = time.time() + x, y, _, y_mel = batch + + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + y_mel = y_mel.to(device, non_blocking=True) + y = y.unsqueeze(1) + + y_g_hat = generator(x) + y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, + h.fmin, h.fmax_for_loss) + + optim_d.zero_grad() + + # MPD + y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) + loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) + + # MRD + y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach()) + loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) + + loss_disc_all = loss_disc_s + loss_disc_f + + # whether to freeze D for initial training steps + if steps >= a.freeze_step: + loss_disc_all.backward() + grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), 1000.) + grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), 1000.) + optim_d.step() + else: + print("WARNING: skipping D training for the first {} steps".format(a.freeze_step)) + grad_norm_mpd = 0. + grad_norm_mrd = 0. + + # generator + optim_g.zero_grad() + + # L1 Mel-Spectrogram Loss + loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45 + + # MPD loss + y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) + loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) + loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) + + # MRD loss + y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(y, y_g_hat) + loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) + loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) + + if steps >= a.freeze_step: + loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel + else: + print("WARNING: using regression loss only for G for the first {} steps".format(a.freeze_step)) + loss_gen_all = loss_mel + + loss_gen_all.backward() + grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), 1000.) + optim_g.step() + + if rank == 0: + # STDOUT logging + if steps % a.stdout_interval == 0: + with torch.no_grad(): + mel_error = F.l1_loss(y_mel, y_g_hat_mel).item() + + print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. 
Error : {:4.3f}, s/b : {:4.3f}'. + format(steps, loss_gen_all, mel_error, time.time() - start_b)) + + # checkpointing + if steps % a.checkpoint_interval == 0 and steps != 0: + checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps) + save_checkpoint(checkpoint_path, + {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()}) + checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps) + save_checkpoint(checkpoint_path, + {'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(), + 'mrd': (mrd.module if h.num_gpus > 1 else mrd).state_dict(), + 'optim_g': optim_g.state_dict(), + 'optim_d': optim_d.state_dict(), + 'steps': steps, + 'epoch': epoch}) + + # Tensorboard summary logging + if steps % a.summary_interval == 0: + sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) + sw.add_scalar("training/mel_spec_error", mel_error, steps) + sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps) + sw.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps) + sw.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps) + sw.add_scalar("training/grad_norm_mpd", grad_norm_mpd, steps) + sw.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps) + sw.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps) + sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps) + sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps) + sw.add_scalar("training/grad_norm_g", grad_norm_g, steps) + sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps) + sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps) + sw.add_scalar("training/epoch", epoch+1, steps) + + # validation + if steps % a.validation_interval == 0: + # plot training input x so far used + for i_x in range(x.shape[0]): + sw.add_figure('training_input/x_{}'.format(i_x), plot_spectrogram(x[i_x].cpu()), steps) + sw.add_audio('training_input/y_{}'.format(i_x), y[i_x][0], steps, h.sampling_rate) + + # seen and unseen speakers validation loops + if not a.debug and steps != 0: + validate(rank, a, h, validation_loader, + mode="seen_{}".format(train_loader.dataset.name)) + for i in range(len(list_unseen_validation_loader)): + validate(rank, a, h, list_unseen_validation_loader[i], + mode="unseen_{}".format(list_unseen_validation_loader[i].dataset.name)) + steps += 1 + + scheduler_g.step() + scheduler_d.step() + + if rank == 0: + print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start))) + + +def main(): + print('Initializing Training Process..') + + parser = argparse.ArgumentParser() + + parser.add_argument('--group_name', default=None) + + parser.add_argument('--input_wavs_dir', default='LibriTTS') + parser.add_argument('--input_mels_dir', default='ft_dataset') + parser.add_argument('--input_training_file', default='LibriTTS/train-full.txt') + parser.add_argument('--input_validation_file', default='LibriTTS/val-full.txt') + + parser.add_argument('--list_input_unseen_wavs_dir', nargs='+', default=['LibriTTS', 'LibriTTS']) + parser.add_argument('--list_input_unseen_validation_file', nargs='+', default=['LibriTTS/dev-clean.txt', 'LibriTTS/dev-other.txt']) + + parser.add_argument('--checkpoint_path', default='exp/bigvgan') + parser.add_argument('--config', default='') + + parser.add_argument('--training_epochs', default=100000, type=int) + parser.add_argument('--stdout_interval', default=5, type=int) + parser.add_argument('--checkpoint_interval', default=50000, type=int) + 
parser.add_argument('--summary_interval', default=100, type=int) + parser.add_argument('--validation_interval', default=50000, type=int) + + parser.add_argument('--freeze_step', default=0, type=int, + help='freeze D for the first specified steps. G only uses regression loss for these steps.') + + parser.add_argument('--fine_tuning', default=False, type=bool) + + parser.add_argument('--debug', default=False, type=bool, + help="debug mode. skips validation loop throughout training") + parser.add_argument('--evaluate', default=False, type=bool, + help="only run evaluation from checkpoint and exit") + parser.add_argument('--eval_subsample', default=5, type=int, + help="subsampling during evaluation loop") + parser.add_argument('--skip_seen', default=False, type=bool, + help="skip seen dataset. useful for test set inference") + parser.add_argument('--save_audio', default=False, type=bool, + help="save audio of test set inference to disk") + + a = parser.parse_args() + + with open(a.config) as f: + data = f.read() + + json_config = json.loads(data) + h = AttrDict(json_config) + + build_env(a.config, 'config.json', a.checkpoint_path) + + torch.manual_seed(h.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(h.seed) + h.num_gpus = torch.cuda.device_count() + h.batch_size = int(h.batch_size / h.num_gpus) + print('Batch size per GPU :', h.batch_size) + else: + pass + + if h.num_gpus > 1: + mp.spawn(train, nprocs=h.num_gpus, args=(a, h,)) + else: + train(0, a, h) + + +if __name__ == '__main__': + main() diff --git a/FlashSR/BigVGAN/utils.py b/FlashSR/BigVGAN/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c05fed6fa80bcb2e222cc43bc6b57d61f99a6bab --- /dev/null +++ b/FlashSR/BigVGAN/utils.py @@ -0,0 +1,80 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
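
The checkpointing convention in train.py pairs a generator file (prefix g_) with a discriminator/optimizer file (prefix do_), both named with an 8-digit step count, and resumes from the newest pair found in --checkpoint_path. A rough resume sketch, using the argparse default directory as a placeholder and the scan_checkpoint/load_checkpoint helpers from utils.py below:

import torch
from utils import scan_checkpoint, load_checkpoint

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cp_g = scan_checkpoint('exp/bigvgan', 'g_')    # newest generator checkpoint, or None
cp_do = scan_checkpoint('exp/bigvgan', 'do_')  # newest discriminator/optimizer checkpoint, or None

if cp_g is None or cp_do is None:
    steps, last_epoch = 0, -1                  # fresh run
else:
    state_dict_g = load_checkpoint(cp_g, device)
    state_dict_do = load_checkpoint(cp_do, device)
    # generator.load_state_dict(state_dict_g['generator'])
    # mpd.load_state_dict(state_dict_do['mpd']); mrd.load_state_dict(state_dict_do['mrd'])
    steps = state_dict_do['steps'] + 1
    last_epoch = state_dict_do['epoch']
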
+ +import glob +import os +import matplotlib +import torch +from torch.nn.utils import weight_norm +matplotlib.use("Agg") +import matplotlib.pylab as plt +from FlashSR.BigVGAN.meldataset import MAX_WAV_VALUE +from scipy.io.wavfile import write + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def plot_spectrogram_clipped(spectrogram, clip_max=2.): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none', vmin=1e-6, vmax=clip_max) + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + +def save_audio(audio, path, sr): + # wav: torch with 1d shape + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype('int16') + write(path, sr, audio) \ No newline at end of file diff --git a/FlashSR/FlashSR.py b/FlashSR/FlashSR.py new file mode 100644 index 0000000000000000000000000000000000000000..578b5d16ff54312fcf730ee1d3ae3d81a94e4eaf --- /dev/null +++ b/FlashSR/FlashSR.py @@ -0,0 +1,103 @@ +from typing import Optional, Tuple, Union + +import torch + +from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM +from TorchJaekwon.Model.Diffusion.External.diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler +from TorchJaekwon.Model.Diffusion.External.diffusers.DiffusersWrapper import DiffusersWrapper + +from FlashSR.AudioSR.AudioSRUnet import AudioSRUnet +from FlashSR.VAEWrapper import VAEWrapper +from FlashSR.SRVocoder import SRVocoder +from FlashSR.Util.UtilAudioSR import UtilAudioSR +from FlashSR.Util.UtilAudioLowPassFilter import UtilAudioLowPassFilter + +class FlashSR(DDPM): + def __init__( + self, + student_ldm_ckpt_path:str, + sr_vocoder_ckpt_path:str, + autoencoder_ckpt_path:str, + model_output_type:str = 'v_prediction', + beta_schedule_type:str = 'cosine', + **kwargs + ) -> None: + + super().__init__(model = AudioSRUnet(), model_output_type=model_output_type, beta_schedule_type=beta_schedule_type, **kwargs) + + student_ldm_state_dict = torch.load(student_ldm_ckpt_path) + self.load_state_dict(student_ldm_state_dict) + + self.vae = VAEWrapper(autoencoder_ckpt_path) + + self.sr_vocoder = SRVocoder() + sr_vocoder_state_dict = torch.load(sr_vocoder_ckpt_path) + self.sr_vocoder.load_state_dict(sr_vocoder_state_dict) + + def forward(self, + lr_audio:torch.Tensor, #[batch, time] ex) [4, 245760] + num_steps:int = 1, + lowpass_input:bool = True, + 
lowpass_cutoff_freq:int = None + ) -> torch.Tensor: #[batch, time] ex) [4, 245760] + + if lowpass_input: + device = lr_audio.device + if lowpass_cutoff_freq is None: + lowpass_cutoff_freq:int = UtilAudioSR.find_cutoff_freq(lr_audio) + lr_audio = lr_audio.cpu().numpy() + lr_audio = UtilAudioLowPassFilter.lowpass(lr_audio, 48000, filter_name='cheby', filter_order=8, cutoff_freq=lowpass_cutoff_freq) + lr_audio = torch.from_numpy(lr_audio).to(device) + + with torch.no_grad(): + pred_hr_audio = DiffusersWrapper.infer( + ddpm_module=self, + diffusers_scheduler_class=DPMSolverMultistepScheduler, + x_shape=None, + cond = lr_audio, + num_steps=num_steps, + device=lr_audio.device + ) + pred_hr_audio = pred_hr_audio[...,:lr_audio.shape[-1]] + return pred_hr_audio + + def preprocess(self, + x_start:torch.Tensor, # [batch, time] + cond:Optional[Union[dict,torch.Tensor]] = None, # [batch, time] + ) -> Tuple[torch.Tensor, torch.Tensor]: #( [batch, 1 , mel, time//hop] , [batch, 1 , mel, time//hop] ) + device = cond.device + if self.vae.device != device: + self.vae.to(device=device) + x_dict = dict() + + cond_dict = self.vae.encode_to_z(cond) + + if x_start is not None: + state_dict:dict = { + 'mean_scale_factor': cond_dict['mean_scale_factor'], + 'var_scale_factor': cond_dict['var_scale_factor'] + } + x_dict = self.vae.encode_to_z(x_start, scale_dict=state_dict) ##[batch, 16, time / (hop * 8), mel_bin / 8] + + return x_dict.get('z', None), cond_dict['z'], cond_dict + + def postprocess(self, + x:torch.Tensor, #[batch, 1, mel, time] + additional_data_dict:dict) -> torch.Tensor: + mel_spec = self.vae.z_to_mel(x) + mel_spec = mel_spec.squeeze(1).transpose(1,2) + pred_hr_audio = self.sr_vocoder(mel_spec, additional_data_dict['norm_wav'])['pred_hr_audio'] + pred_hr_audio = self.vae.denormalize_wav(pred_hr_audio, additional_data_dict) + return pred_hr_audio + + def get_x_shape(self, cond): + return cond.shape + + def get_unconditional_condition(self, + cond:Optional[Union[dict,torch.Tensor]] = None, + cond_shape:Optional[tuple] = None, + condition_device:Optional[torch.device] = None + ) -> torch.Tensor: + if cond_shape is None: cond_shape = cond.shape + if cond is not None and isinstance(cond,torch.Tensor): condition_device = cond.device + return (-11.4981 + torch.zeros(cond_shape)).to(condition_device) * self.vae.scale_factor_z \ No newline at end of file diff --git a/FlashSR/SRVocoder.py b/FlashSR/SRVocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a94a53e249f4f959e71a8b34d2ec780c06f22c01 --- /dev/null +++ b/FlashSR/SRVocoder.py @@ -0,0 +1,212 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
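
FlashSR itself is a wrapper around the student LDM (AudioSRUnet), the VAE, and the SR vocoder defined below; its forward() takes a low-resolution waveform of shape [batch, time] at 48 kHz and returns a super-resolved waveform trimmed to the same length. A hypothetical end-to-end call, with placeholder checkpoint paths (the real filenames are not part of this diff):

import torch
from FlashSR.FlashSR import FlashSR

model = FlashSR(
    student_ldm_ckpt_path='ckpts/student_ldm.pth',   # placeholder paths
    sr_vocoder_ckpt_path='ckpts/sr_vocoder.pth',
    autoencoder_ckpt_path='ckpts/vae.pth',
)

lr_audio = torch.randn(1, 245760)                    # [batch, time] at 48 kHz, as in the shape comments above
with torch.no_grad():
    hr_audio = model(lr_audio, num_steps=1, lowpass_input=True)
print(hr_audio.shape)                                # trimmed to the input length: [1, 245760]
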
+from torch import Tensor + +from TorchJaekwon.Util.Util import Util +from TorchJaekwon.Util.UtilData import UtilData +from TorchJaekwon.Util.UtilAudioMelSpec import UtilAudioMelSpec +#from easydict import EasyDict +#Util.set_sys_path_to_parent_dir(__file__, depth_to_dir_from_file=2) + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +import FlashSR.BigVGAN.activations as activations +from FlashSR.BigVGAN.utils import init_weights, get_padding +from FlashSR.BigVGAN.alias_free_torch import * + +LRELU_SLOPE = 0.1 + +class SRVocoder(torch.nn.Module): + def __init__(self, + num_mels = 256, + upsample_initial_channel = 1536, + resblock_kernel_sizes = [3, 7, 11], + resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates = [10, 6, 2, 2, 2], #[4, 4, 2, 2, 2, 2], upsample_rates = [5, 4, 3, 2, 2, 2], #[4, 4, 2, 2, 2, 2], + upsample_kernel_sizes = None, # upsample_kernel_sizes = [7,8,7,4,4,4], + activation = 'snakebeta', + snake_logscale = True + ): + super(SRVocoder, self).__init__() + if upsample_kernel_sizes is None: + upsample_kernel_sizes = [upsample_rate * 2 for upsample_rate in upsample_rates] + + self.audio_block = nn.ModuleDict() + self.audio_block["downsamples"] = nn.ModuleList() + self.audio_block["emb"] = Conv1d( 1, upsample_initial_channel // (2 ** len(upsample_rates)), 7, bias=True, padding=(7 - 1) // 2, ) + for i in reversed(range(len(upsample_kernel_sizes))): + self.audio_block["downsamples"] += [ + nn.Sequential( + nn.Conv1d( + upsample_initial_channel // (2 ** (i + 1)), + upsample_initial_channel // (2 ** i), + upsample_kernel_sizes[i], + upsample_rates[i], + padding=upsample_rates[i] - (upsample_kernel_sizes[i] % 2 == 0), + bias=True, + ), + nn.LeakyReLU(negative_slope = 0.1) + ) + ] + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + + # pre conv + self.conv_pre = weight_norm(Conv1d(num_mels, upsample_initial_channel, 7, 1, padding=3)) + + # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + resblock = AMPBlock1 + + # transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append(nn.ModuleList([ + weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), + upsample_initial_channel // (2 ** (i + 1)), + k, u, padding=(k - u) // 2)) + ])) + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d, activation=activation)) + + # post conv + if activation == "snake": # periodic nonlinearity with snake function and anti-aliasing + activation_post = activations.Snake(ch, alpha_logscale=snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + elif activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing + activation_post = activations.SnakeBeta(ch, alpha_logscale=snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + else: + raise NotImplementedError("activation incorrectly specified. 
check the config file and look for 'activation'.") + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + + # weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + ''' + In audio sr + sampling_rate = 48000 + filter_length = 2048 + hop_length = 480 + win_length = 2048 + n_mel = 256 + mel_fmin = 20 + mel_fmax = 24000 + ''' + + def forward(self, + mel_spec:Tensor, #[batch, mel_size, time//hop] + lr_audio:Tensor, #[batch, time] + ) -> Tensor: #[batch, time] + + audio_emb:Tensor = self.audio_block["emb"](lr_audio.unsqueeze(1)) + audio_emb_list:list = [audio_emb] + for i in range(self.num_upsamples - 1): + audio_emb = self.audio_block["downsamples"][i](audio_emb) + audio_emb_list += [audio_emb] + + # pre conv + x = self.conv_pre(mel_spec) + + for i in range(self.num_upsamples): + # upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + audio_emb_list[-1-i] + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # post conv + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x).squeeze(1) + + return {'pred_hr_audio': x } + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class AMPBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None, snake_logscale = 'snakebeta'): + super(AMPBlock1, self).__init__() + + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError("activation incorrectly specified. 
check the config file and look for 'activation'.") + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) diff --git a/FlashSR/Util/UtilAudioLowPassFilter.py b/FlashSR/Util/UtilAudioLowPassFilter.py new file mode 100644 index 0000000000000000000000000000000000000000..31a5ae14e0883ab4675bb5668cb3b27afb0b583d --- /dev/null +++ b/FlashSR/Util/UtilAudioLowPassFilter.py @@ -0,0 +1,109 @@ + +#from TorchJaekwon.Util.Util import Util +#Util.set_sys_path_to_parent_dir(__file__, 2) + +from typing import Literal +from numpy import ndarray + +import numpy as np +from scipy.signal import butter, cheby1, cheby2, ellip, bessel, sosfiltfilt, resample_poly + +class UtilAudioLowPassFilter: + # this code is refactored version of https://github.com/haoheliu/ssr_eval + + @staticmethod + def lowpass(audio:ndarray, #[time] 1d array + sr:int, + filter_name:Literal["cheby","butter","bessel","ellip"], + filter_order:int, + cutoff_freq:int, + upsample_to_original:bool = True + ): + assert len(audio.shape) == 1 or (len(audio.shape) == 2 and (audio.shape[0] == 1 or audio.shape[0] == 2)) + if filter_name == "cheby": filter_name = "cheby1" + assert filter_order >= 2 and filter_order <= 10, f"filter_order should be between 2 and 10, but {filter_order} is given" + if cutoff_freq == sr: cutoff_freq -= 1 + if len(audio.shape) == 2: + lowpassed_audio = np.zeros_like(audio) + for i in range(audio.shape[0]): + lowpassed_audio[i] = UtilAudioLowPassFilter.lowpass_filter( x=audio[i], highcutoff_freq=int(cutoff_freq), fs=sr, order=filter_order, ftype=filter_name, upsample_to_original = upsample_to_original) + else: + lowpassed_audio = UtilAudioLowPassFilter.lowpass_filter( x=audio, highcutoff_freq=int(cutoff_freq), fs=sr, order=filter_order, ftype=filter_name, upsample_to_original = upsample_to_original) + if upsample_to_original: + assert lowpassed_audio.shape == audio.shape, f'error lowpass_butterworth: {str((lowpassed_audio.shape, audio.shape))}' + return lowpassed_audio.copy() # avoid the problem [Torch.from_numpy not support negative strides] + + + + @staticmethod + def lowpass_filter(x:ndarray, #[time] 1d array + highcutoff_freq:float, #high cutoff frequency + fs:int, + order:int, #the order of filter + ftype:Literal['butter', 'cheby1', 'cheby2', 'ellip', 'bessel'], + upsample_to_original:bool = True + ) -> ndarray: #[time] 1d array + nyq = 0.5 * fs + hi = highcutoff_freq / nyq + if ftype == "butter": + sos = butter(order, hi, btype="low", output="sos") + elif ftype == "cheby1": + sos = cheby1(order, 0.1, hi, btype="low", output="sos") + elif ftype == "cheby2": + sos = cheby2(order, 60, hi, btype="low", output="sos") + elif ftype == "ellip": + sos = ellip(order, 0.1, 60, hi, btype="low", output="sos") + elif ftype == "bessel": + sos = bessel(order, hi, btype="low", output="sos") + else: + raise Exception(f"The lowpass filter {ftype} is not supported!") + + y = sosfiltfilt(sos, x) + + if len(y) != len(x): + y = UtilAudioLowPassFilter.align_length(x, y) + # After low pass filtering. 
Resample the audio signal + y = UtilAudioLowPassFilter.subsampling(y, lowpass_ratio=highcutoff_freq / int(fs / 2), fs_ori=fs, upsample_to_original = upsample_to_original) + return y + + @staticmethod + def align_length(x, y): + """align the length of y to that of x + + Args: + x (np.array): reference signal + y (np.array): the signal needs to be length aligned + + Return: + yy (np.array): signal with the same length as x + """ + Lx = len(x) + Ly = len(y) + + if Lx == Ly: + return y + elif Lx > Ly: + # pad y with zeros + return np.pad(y, (0, Lx - Ly), mode="constant") + else: + # cut y + return y[:Lx] + + @staticmethod + def subsampling(data, lowpass_ratio, fs_ori=44100, upsample_to_original:bool = True): + assert len(data.shape) == 1 + fs_down = int(lowpass_ratio * fs_ori) + # downsample to the low sampling rate + y = resample_poly(data, fs_down, fs_ori) + + if upsample_to_original: + # upsample to the original sampling rate + y = resample_poly(y, fs_ori, fs_down) + + if len(y) != len(data): + y = UtilAudioLowPassFilter.align_length(data, y) + return y + +if __name__ == "__main__": + util = UtilAudioLowPassFilter() + util.lowpass(np.zeros(24000),48000, filter_name="cheby", filter_order=8, cutoff_freq=8000) \ No newline at end of file diff --git a/FlashSR/Util/UtilAudioSR.py b/FlashSR/Util/UtilAudioSR.py new file mode 100644 index 0000000000000000000000000000000000000000..6e003b460c7f57aec2852820a75da3542446c61d --- /dev/null +++ b/FlashSR/Util/UtilAudioSR.py @@ -0,0 +1,95 @@ +# source: https://github.com/haoheliu/versatile_audio_super_resolution +import librosa +import numpy as np +import torch + +from TorchJaekwon.Util.UtilData import UtilData +from TorchJaekwon.Util.UtilTorch import UtilTorch + +class UtilAudioSR: + + @staticmethod + def mel_replace_ops(predict_mel:torch.Tensor, #[batch, 1, time, melbin], log mel spectrogram + gt_low_pass_mel:torch.Tensor, #[batch, 1, time, melbin], log mel spectrogram + debug_message:bool=False + ) -> torch.Tensor: #[batch, 1, time, melbin] + batch_size = predict_mel.size(0) + for i in range(batch_size): + cutoff_melbin = UtilAudioSR.locate_cutoff_freq(torch.exp(gt_low_pass_mel[i].squeeze())) + + if debug_message: + ratio = predict_mel[i][...,:cutoff_melbin]/gt_low_pass_mel[i][...,:cutoff_melbin] + print(torch.mean(ratio), torch.max(ratio), torch.min(ratio)) + + predict_mel[i][..., :cutoff_melbin] = gt_low_pass_mel[i][..., :cutoff_melbin] + return predict_mel + + @staticmethod + def locate_cutoff_freq(stft, percentile=0.985): + magnitude = torch.abs(stft) + energy = torch.cumsum(torch.sum(magnitude, dim=0), dim=0) + return UtilAudioSR.find_cutoff(energy, percentile) + + @staticmethod + def find_cutoff(x, percentile=0.95): + percentile = x[-1] * percentile + for i in range(1, x.shape[0]): + if x[-i] < percentile: + return x.shape[0] - i + return 0 + + @staticmethod + def wav_replace_ops(pred_wav:torch.Tensor, #[batch, 1, time] + gt_low_pass_wav:torch.Tensor #[batch, 1, time] + ) -> torch.Tensor: #[batch, 1, time] + device = pred_wav.device + pred_wav = UtilData.fit_shape_length(pred_wav, 2).cpu().detach().numpy() + gt_low_pass_wav = UtilData.fit_shape_length(gt_low_pass_wav, 2).cpu().detach().numpy() + for i in range(pred_wav.shape[0]): + + out = pred_wav[i] + x = gt_low_pass_wav[i] + cutoffratio = UtilAudioSR.get_cutoff_index_np(x) + + length = out.shape[0] + stft_gt = librosa.stft(x) + + stft_out = librosa.stft(out) + energy_ratio = np.mean( + np.sum(np.abs(stft_gt[cutoffratio])) + / np.sum(np.abs(stft_out[cutoffratio, ...])) + ) + energy_ratio = 
min(max(energy_ratio, 0.8), 1.2) + stft_out[:cutoffratio, ...] = stft_gt[:cutoffratio, ...] / energy_ratio + + out_renewed = librosa.istft(stft_out, length=length) + pred_wav[i] = out_renewed + + return torch.FloatTensor(pred_wav).to(device) + + @staticmethod + def get_cutoff_index_np(x): + stft_x = np.abs(librosa.stft(x)) + energy = np.cumsum(np.sum(stft_x, axis=-1)) + return UtilAudioSR.find_cutoff(energy, 0.985) + + @staticmethod + def find_cutoff_freq(audio:torch.Tensor) -> int: + stft_spec = torch.stft( + input = audio, + n_fft = 2048, + hop_length=480, + win_length=2048, + window=torch.hann_window(2048).to(audio.device), + center=False, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + + stft_spec = stft_spec[0].T.float() + cutoff_freq = (UtilAudioSR.locate_cutoff_freq(stft_spec, percentile=0.983) / 1024) * 24000 + if(cutoff_freq < 1000): + cutoff_freq = 24000 + return cutoff_freq \ No newline at end of file diff --git a/FlashSR/VAEWrapper.py b/FlashSR/VAEWrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..1715119fd8128600a09fe50be35e366aba4006a6 --- /dev/null +++ b/FlashSR/VAEWrapper.py @@ -0,0 +1,125 @@ +from typing import Union, Optional +from numpy import ndarray +from torch import Tensor + +import os + +import torch +import torch.nn as nn + +from TorchJaekwon.Util.UtilData import UtilData +from TorchJaekwon.Util.UtilAudioMelSpec import UtilAudioMelSpec + +from FlashSR.AudioSR.autoencoder import AutoencoderKL +from TorchJaekwon.Util.UtilTorch import UtilTorch + +class VAEWrapper: + def __init__(self, + autoencoder_ckpt_path:str, + sr:int = '48000', + frame_sec:float = 5.12, + device:torch.device = torch.device('cpu'), + scale_factor_z:float = 0.3342 + ) -> None: + vocoder_config_dir:str = f'{os.path.dirname(os.path.abspath(__file__))}/AudioSR/args' + + self.sr:int = sr + self.frame_sec:float = frame_sec + self.scale_factor_z:float = scale_factor_z + self.device:torch.device = device + + autoencoder:nn.Module = AutoencoderKL(**UtilData.yaml_load(f'{vocoder_config_dir}/model_argument.yaml')) + autoencoder_ckpt = torch.load(autoencoder_ckpt_path, map_location='cpu') + autoencoder.load_state_dict(autoencoder_ckpt) + autoencoder = autoencoder.to(device) + self.autoencoder = UtilTorch.freeze_param(autoencoder) + + self.mel_config:dict = UtilData.yaml_load(f'{vocoder_config_dir}/mel_argument.yaml') + self.util_mel_spec = UtilAudioMelSpec(**self.mel_config) + + def to(self, device): + self.device = device + self.autoencoder = self.autoencoder.to(self.device) + + @torch.no_grad() + def encode_to_z(self, + audio:Union[ndarray,Tensor], # [batch, time] + normalize:bool = True, + scale_dict:dict = None, + ) -> dict: #[batch, 16, time / (hop * 8), mel_bin / 8] mel_bin: 256 + assert len(audio.shape) == 2, f'audio shape must be [batch, time] but got {audio.shape}' + result_dict:dict = {'wav': audio} + + if normalize: + audio, scale_dict = self.normalize_wav(audio, scale_dict=scale_dict) + result_dict['norm_wav'] = audio + result_dict.update(scale_dict) + + mel_spec:Tensor = self.audio_to_mel(audio) + result_dict['mel_spec'] = mel_spec + + encoder_posterior = self.autoencoder.encode(mel_spec) + z = encoder_posterior.sample() * self.scale_factor_z + + result_dict['z'] = z + return result_dict + + def normalize_wav(self, waveform:Union[Tensor], scale_dict:dict): + mean_scale_factor = torch.mean(waveform, dim=1, keepdim=True) if scale_dict is None else scale_dict['mean_scale_factor'] + waveform = waveform - mean_scale_factor + + 
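+        # Per-example peak scale factor; when scale_dict is provided (e.g. when
+        # FlashSR.preprocess encodes the HR target), the factors computed from the
+        # low-resolution conditioning audio are reused instead of being recomputed.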
var_scale_factor = torch.max(torch.abs(waveform), dim=1, keepdim=True)[0] if scale_dict is None else scale_dict['var_scale_factor'] + + waveform = waveform / (var_scale_factor + 1e-8) + return waveform * 0.5, {'mean_scale_factor':mean_scale_factor, 'var_scale_factor':var_scale_factor} + + def denormalize_wav(self, waveform:Union[Tensor], scale_dict:dict): + waveform = waveform * 2.0 + waveform = waveform * (scale_dict['var_scale_factor'] + 1e-8) + waveform = waveform + scale_dict['mean_scale_factor'] + return waveform + + def get_mel_spec(self, audio:Union[ndarray,Tensor]): + return self.util_mel_spec.get_hifigan_mel_spec(audio).to(self.device) + + @torch.no_grad() + def audio_to_mel(self, audio): + mel_spec:Tensor = self.util_mel_spec.get_hifigan_mel_spec(audio).to(self.device) + if len(mel_spec.shape) == 3: #to make [batch, channel, freq, time] + mel_spec = mel_spec.unsqueeze(1) + return mel_spec.permute(0, 1, 3, 2) + + def z_to_audio(self,z:Tensor, scale_dict:dict = None, with_no_grad:bool = True): + if with_no_grad: + with torch.no_grad(): + mel_spec = self.z_to_mel(z) + audio = self.mel_to_audio(mel_spec, scale_dict) + return audio + else: + mel_spec = self.z_to_mel(z, with_no_grad=False) + audio = self.mel_to_audio(mel_spec, scale_dict, with_no_grad=False) + return audio + + def z_to_mel(self,z:Tensor, with_no_grad:bool = True): + if with_no_grad: + with torch.no_grad(): + z = (1.0 / self.scale_factor_z) * z + mel_spec = self.autoencoder.decode(z) + return mel_spec + else: + z = (1.0 / self.scale_factor_z) * z + mel_spec = self.autoencoder.decode(z) + return mel_spec + + def mel_to_audio(self, mel_spec:Tensor, scale_dict:dict = None, with_no_grad:bool = True): + if with_no_grad: + with torch.no_grad(): + mel_spec = mel_spec.permute(0, 1, 3, 2).squeeze(1) + audio = self.autoencoder.vocoder(mel_spec) + if scale_dict is not None: audio = self.denormalize_wav(audio) + return audio + else: + mel_spec = mel_spec.permute(0, 1, 3, 2).squeeze(1) + audio = self.autoencoder.vocoder(mel_spec) + if scale_dict is not None: audio = self.denormalize_wav(audio) + return audio \ No newline at end of file diff --git a/FlashSR/__init__.py b/FlashSR/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b958f7c2362b91d526eba60bd025f86c84a8acc2 --- /dev/null +++ b/README.md @@ -0,0 +1,172 @@ +--- +language: + - en +tags: + - audio + - super-resolution + - speech-enhancement + - diffusion + - one-step +pipeline_tag: audio-to-audio +--- + +# FlashSR: One-step Versatile Audio Super-Resolution + +> **This is a convenience redistribution, not the original repository.** All credit for the model architecture, research, training, and weights belongs to the original authors. This repository is not affiliated with or endorsed by them. + +| | | +|---|---| +| **Authors** | Jaekwon Im and Juhan Nam (KAIST) | +| **Paper** | [FlashSR: One-step Versatile Audio Super-resolution via Diffusion Distillation](https://arxiv.org/abs/2501.10807) (arXiv:2501.10807) | +| **Demo** | [jakeoneijk.github.io/flashsr-demo](https://jakeoneijk.github.io/flashsr-demo/) | +| **Original code** | [jakeoneijk/FlashSR_Inference](https://github.com/jakeoneijk/FlashSR_Inference) | +| **Original weights** | [jakeoneijk/FlashSR_weights](https://huggingface.co/datasets/jakeoneijk/FlashSR_weights) | + +> **Note:** There are other unrelated projects also named "FlashSR" (e.g. 
for image super-resolution). This repository specifically refers to the **audio** super-resolution model by Im & Nam (2025), based on diffusion distillation for one-step inference. + +## About this repository + +The original code and weights are split across GitHub and Hugging Face and have dependencies (torchcodec, FFmpeg) that can be difficult to set up. This repository bundles everything into one place with a standalone inference script that only needs PyTorch, soundfile, and scipy. + +**What is from the original authors:** The model code (`FlashSR/`, `TorchJaekwon/`) and the pretrained weights (`weights/`) are from the original repositories linked above. + +**What is new in this redistribution:** The inference script (`enhance.py`), `setup.py`, and this README were written independently. The code in this repository (excluding model weights) is released under the **Apache License 2.0**. + +## What FlashSR does + +FlashSR restores high-frequency audio components in a single forward pass. It takes audio at any sample rate, resamples to 48 kHz, and reconstructs missing high-frequency detail. This is useful for: + +- Upscaling low-sample-rate recordings to full bandwidth +- Enhancing audio that has been through lossy processing (codecs, vocoders, etc.) +- Post-processing TTS or voice conversion outputs + +The model handles speech, music, and sound effects. + +## Repository structure + +``` +weights/ + student_ldm.pth (986 MB) - Distilled latent diffusion model + sr_vocoder.pth (599 MB) - Super-resolution vocoder + vae.pth (1.6 GB) - Variational autoencoder +FlashSR/ - Model code (from original repo) +TorchJaekwon/ - Utility library (from original repo) +Assets/ExampleInput/ - Example audio files (speech, music, sound effects) +enhance.py - Standalone inference script +setup.py - Package installer +``` + +## Installation + +**Requirements:** Python 3.10+, PyTorch 2.0+ with CUDA, ~6 GB GPU memory. + +```bash +# Clone this repository +git clone https://huggingface.co/laion/FlashSR_One-step_Versatile_Audio_Super-resolution +cd FlashSR_One-step_Versatile_Audio_Super-resolution + +# Install +pip install -e . 
+pip install einops librosa soundfile tqdm scipy +``` + +### Verify + +```bash +python enhance.py --input Assets/ExampleInput/speech.wav --output output.wav +``` + +> **Tip:** If you have a conda environment with conflicting cudnn libraries, clear `LD_LIBRARY_PATH` before running: `LD_LIBRARY_PATH="" python enhance.py ...` + +## Usage + +### Command line + +```bash +# Single file +python enhance.py --input my_audio.wav --output enhanced.wav + +# Entire directory +python enhance.py --input ./audio_folder/ --output ./enhanced_folder/ + +# With lowpass filter (can help when input was not originally bandwidth-limited) +python enhance.py --input my_audio.wav --output enhanced.wav --lowpass + +# Specify GPU +CUDA_VISIBLE_DEVICES=0 python enhance.py --input my_audio.wav --output enhanced.wav +``` + +### Python API + +```python +import torch +import soundfile as sf +import numpy as np +from pathlib import Path +from FlashSR.FlashSR import FlashSR + +WEIGHTS_DIR = Path("./weights") +WINDOW_SIZE = 245760 # 5.12 seconds at 48 kHz + +# Initialize +model = FlashSR( + student_ldm_ckpt_path=str(WEIGHTS_DIR / "student_ldm.pth"), + sr_vocoder_ckpt_path=str(WEIGHTS_DIR / "sr_vocoder.pth"), + autoencoder_ckpt_path=str(WEIGHTS_DIR / "vae.pth"), +) +model = model.to("cuda").eval() + +# Load and prepare audio (must be mono, 48 kHz) +samples, rate = sf.read("input.wav", dtype="float32") +if samples.ndim > 1: + samples = samples.mean(axis=1) + +# The model accepts exactly 245760 samples per call. +# Pad short audio; for longer audio, see enhance.py for chunk-based processing. +waveform = torch.from_numpy(samples).unsqueeze(0) # shape: (1, num_samples) +n = waveform.shape[-1] +if n < WINDOW_SIZE: + waveform = torch.nn.functional.pad(waveform, (0, WINDOW_SIZE - n)) + +waveform = waveform.to("cuda") + +with torch.no_grad(): + result = model(waveform, lowpass_input=False) + +# Trim padding and save +result = result[:, :n].squeeze(0).cpu().numpy() +sf.write("output.wav", result, 48000) +``` + +## Notes + +- **Fixed input length:** The model processes exactly 245,760 samples (5.12 seconds at 48 kHz). The `enhance.py` script handles longer audio automatically using overlapping chunks with crossfading. +- **Sample rate:** Input audio at any sample rate is resampled to 48 kHz. Output is always 48 kHz. +- **Channels:** Mono and stereo are both supported. Stereo files are processed channel-by-channel. +- **`lowpass_input` flag:** Set to `True` if your input was not originally bandwidth-limited. This applies a lowpass filter before enhancement to better match the model's training distribution. + +## License + +The inference script (`enhance.py`), `setup.py`, and this README are released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0). + +The model weights and original model code (`FlashSR/`, `TorchJaekwon/`) are from the original authors' repositories linked above. Please refer to those repositories for their licensing terms. 
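+
+## Processing longer audio (illustrative sketch)
+
+The Notes above mention that `enhance.py` handles audio longer than 5.12 seconds by
+enhancing overlapping windows and crossfading them. The sketch below is an independent,
+simplified illustration of that idea using the Python API. The window size matches the
+model's fixed input length, but the one-second overlap, the linear crossfade, and the
+helper name `enhance_long` are assumptions made here; `enhance.py` remains the reference
+implementation.
+
+```python
+import numpy as np
+import torch
+
+WINDOW = 245760            # 5.12 s at 48 kHz, the model's fixed input length
+OVERLAP = 48000            # assumed 1 s overlap between consecutive windows
+HOP = WINDOW - OVERLAP
+
+def enhance_long(model, samples: np.ndarray) -> np.ndarray:
+    """samples: mono float32 waveform already resampled to 48 kHz."""
+    out = np.zeros(len(samples), dtype=np.float32)
+    weight = np.zeros(len(samples), dtype=np.float32)
+
+    # Linear fade-in/fade-out so overlapping windows sum to ~1 in the overlap region.
+    fade = np.ones(WINDOW, dtype=np.float32)
+    fade[:OVERLAP] = np.linspace(0.0, 1.0, OVERLAP)
+    fade[-OVERLAP:] = np.linspace(1.0, 0.0, OVERLAP)
+
+    for start in range(0, max(len(samples) - OVERLAP, 1), HOP):
+        chunk = samples[start:start + WINDOW]
+        n = len(chunk)
+        if n < WINDOW:
+            chunk = np.pad(chunk, (0, WINDOW - n))   # pad the last window
+        x = torch.from_numpy(chunk).unsqueeze(0).to("cuda")
+        with torch.no_grad():
+            y = model(x, lowpass_input=False)[0, :n].cpu().numpy()
+        out[start:start + n] += y * fade[:n]
+        weight[start:start + n] += fade[:n]
+
+    # Normalize by the accumulated crossfade weight (avoids attenuated edges).
+    return out / np.maximum(weight, 1e-8)
+```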
+ +## Citation + +If you use FlashSR in your work, please cite the original paper: + +```bibtex +@article{im2025flashsr, + title={FlashSR: One-step Versatile Audio Super-resolution via Diffusion Distillation}, + author={Im, Jaekwon and Nam, Juhan}, + journal={arXiv preprint arXiv:2501.10807}, + year={2025} +} +``` + +## References + +- [AudioSR](https://github.com/haoheliu/versatile_audio_super_resolution) +- [NVSR](https://github.com/haoheliu/ssr_eval) +- [BigVGAN](https://github.com/NVIDIA/BigVGAN) +- [Diffusers](https://github.com/huggingface/diffusers) diff --git a/TorchJaekwon/Controller.py b/TorchJaekwon/Controller.py new file mode 100644 index 0000000000000000000000000000000000000000..32d1fd2b55e2388b7286bcdf23205ee5f2440565 --- /dev/null +++ b/TorchJaekwon/Controller.py @@ -0,0 +1,184 @@ +#type +from typing import Type, Literal, Dict, List, Union, Literal + +#package +import os +import argparse +import numpy as np + +#torchjaekwon +from TorchJaekwon.GetModule import GetModule + +#internal +from HParams import HParams + +class Controller(): + def __init__(self) -> None: + self.set_argparse() + + self.config_name:str = HParams().mode.config_name + self.stage: Literal['preprocess', 'train', 'inference', 'evaluate'] = HParams().mode.stage + + self.config_per_dataset_dict: Dict[str, dict] = HParams().data.config_per_dataset_dict + self.train_mode: Literal['start', 'resume'] = HParams().mode.train + self.train_resume_path: str = HParams().mode.resume_path + self.eval_class_meta:dict = HParams().evaluate.class_meta # {'name': 'Evaluater', 'args': {}} + + def run(self) -> None: + print("=============================================") + print(f"{self.stage} start.") + print("=============================================") + print(f"{self.config_name} start.") + print("=============================================") + + getattr(self,self.stage)() + + print("Finish app.") + + def preprocess(self) -> None: + from TorchJaekwon.DataProcess.Preprocess.Preprocessor import Preprocessor + for data_name in self.config_per_dataset_dict: + for preprocessor_meta in self.config_per_dataset_dict[data_name]['preprocessor_class_meta_list']: + preprocessor_class_name:str = preprocessor_meta['name'] + preprocessor_args:dict = { + 'data_name': data_name, + 'root_dir': HParams().data.root_path, + 'num_workers': HParams().resource.preprocess['num_workers'], + 'device': HParams().resource.device, + } + preprocessor_args.update(preprocessor_meta['args']) + + preprocessor_class:Type[Preprocessor] = GetModule.get_module_class( "./DataProcess/Preprocess", preprocessor_class_name ) + preprocessor:Preprocessor = preprocessor_class(**preprocessor_args) + preprocessor.preprocess_data() + + def train(self) -> None: + import torch + from TorchJaekwon.Train.Trainer.Trainer import Trainer + + train_class_meta:dict = HParams().train.class_meta # {'name': 'Trainer', 'args': {}} + trainer_args:dict = { + 'device': HParams().resource.device, + 'data_class_meta_dict': HParams().pytorch_data.class_meta, + 'model_class_name': HParams().model.class_name, + 'model_class_meta_dict': HParams().model.class_meta_dict, + 'optimizer_class_meta_dict': HParams().train.optimizer['class_meta'], + 'lr_scheduler_class_meta_dict': HParams().train.scheduler['class_meta'], + 'loss_class_meta': HParams().train.loss_dict, + 'max_norm_value_for_gradient_clip': getattr(HParams().train,'max_norm_value_for_gradient_clip',None), + 'total_epoch': getattr(HParams().train, 'total_epoch', int(1e20)), + 'total_step': getattr(HParams().train, 'total_step', 
np.inf), + 'save_model_every_step': getattr(HParams().train, 'save_model_every_step', None), + 'do_log_every_epoch': getattr(HParams().train, 'do_log_every_epoch', True), + 'seed': (int)(torch.cuda.initial_seed() / (2**32)) if HParams().train.seed is None else HParams().train.seed, + 'seed_strict': HParams().train.seed_strict, + 'debug_mode': getattr(HParams().mode, 'debug_mode', False), + 'use_torch_compile': getattr(HParams().mode, 'use_torch_compile', True), + } + trainer_args.update(train_class_meta['args']) + + trainer_class:Type[Trainer] = GetModule.get_module_class('./Train/Trainer', train_class_meta['name']) + trainer:Trainer = trainer_class(**trainer_args) + trainer.init_train() + + if self.train_mode == "resume": + print('resume the training') + trainer.load_train(self.train_resume_path + "/train_checkpoint.pth") + + trainer.fit() + + def inference(self) -> None: + from TorchJaekwon.Inference.Inferencer.Inferencer import Inferencer + + infer_class_meta:dict = HParams().inference.class_meta # {'name': 'Inferencer', 'args': {}} + inferencer_args:dict = { + 'output_dir': HParams().inference.output_dir, + 'experiment_name': HParams().mode.config_name, + 'model': None, + 'model_class_name': HParams().model.class_name, + 'set_type': HParams().inference.set_type, + 'set_meta_dict': HParams().inference.set_meta_dict, + 'device': HParams().resource.device + } + inferencer_args.update(infer_class_meta['args']) + + inferencer_class:Type[Inferencer] = GetModule.get_module_class("./Inference/Inferencer", infer_class_meta['name']) + inferencer:Inferencer = inferencer_class(**inferencer_args) + inferencer.inference( + pretrained_root_dir = HParams().inference.pretrain_root_dir, + pretrained_dir_name = HParams().mode.config_name if HParams().inference.pretrain_dir == '' else HParams().inference.pretrain_dir, + pretrain_module_name = HParams().inference.pretrain_module_name + ) + + def evaluate(self) -> None: + from TorchJaekwon.Evaluater.Evaluater import Evaluater + evaluater_class:Type[Evaluater] = GetModule.get_module_class("./Evaluater", self.eval_class_meta['name']) + evaluater_args:dict = self.eval_class_meta['args'] + evaluater_args.update({ + 'device': HParams().resource.device + }) + if evaluater_args.get('source_dir','') == '': + source_dir_prefix:str = f'{HParams().inference.output_dir}/{HParams().mode.config_name}' + source_dir_parent:str = '/'.join(source_dir_prefix.split('/')[:-1]) + source_dir_tag:str = source_dir_prefix.split('/')[-1] + source_dir_name_candidate = [dir_name for dir_name in os.listdir(source_dir_parent) if source_dir_tag in dir_name] + source_dir_name_candidate.sort() + evaluater_args['source_dir'] = f'{source_dir_parent}/{source_dir_name_candidate[-1]}' + evaluater:Evaluater = evaluater_class(**evaluater_args) + evaluater.evaluate() + + def set_argparse(self) -> None: + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--config_path", + type=str, + required=False, + default=None, + help="", + ) + + parser.add_argument( + "-s", + "--stage", + type=str, + required=False, + default=None, + choices = ['preprocess', 'train', 'inference', 'evaluate'], + help="", + ) + + parser.add_argument( + '-r', + '--resume', + help='train resume', + action='store_true' + ) + + parser.add_argument( + "-do", + "--debug_off", + help="debug mode off", + action='store_true' + ) + + parser.add_argument( + "-lv", + "--log_visualizer", + type=str, + required=False, + default=None, + choices = ['tensorboard', 'wandb'], + help="", + ) + + args = parser.parse_args() + + if 
args.config_path is not None: HParams().set_config(args.config_path) + if args.stage is not None: HParams().mode.stage = args.stage + if args.log_visualizer is not None: HParams().log.visualizer_type = args.log_visualizer + if args.resume: HParams().mode.train = "resume" + if args.debug_off: HParams().mode.debug_mode = False + + return args \ No newline at end of file diff --git a/TorchJaekwon/Data/PytorchDataLoader/BatchSampler/SegmentSampler.py b/TorchJaekwon/Data/PytorchDataLoader/BatchSampler/SegmentSampler.py new file mode 100644 index 0000000000000000000000000000000000000000..0634eac7a9d659f36d4acce28ba7c12ddee63bed --- /dev/null +++ b/TorchJaekwon/Data/PytorchDataLoader/BatchSampler/SegmentSampler.py @@ -0,0 +1,192 @@ +import pickle +from typing import Dict, List, NoReturn + +import numpy as np +from HParams import HParams +from TorchJAEKWON.DataProcess.Util.UtilData import UtilData + +class SegmentSampler: + def __init__( + self, + args_dict: dict + ): + r"""Sample training indexes of sources. + Args: + indexes_path: str, path of indexes dict + input_source_types: list of str, e.g., ['vocals', 'accompaniment'] + target_source_types: list of str, e.g., ['vocals'] + segment_samplers: int + mixaudio_dict, dict, mix-audio data augmentation parameters, + e.g., {'voclas': 2, 'accompaniment': 2} + batch_size: int + steps_per_epoch: int, #steps_per_epoch is called an `epoch` + random_seed: int + """ + self.h_params = HParams() + self.config = args_dict + self.subset = args_dict["subset"] + + self.mix_data_augmentation_num = self.config["mix_data_augmentation_num"] + self.mix_data_augmentation = self.config["mix_data_augmentation"] + self.batch_size = self.h_params.pytorch_data.dataloader[self.subset]["batch_size"] + self.steps_per_epoch = self.config["steps_per_epoch"] + self.source_list:list = self.h_params.preprocess.feature_list + + self.meta_dict = UtilData.pickle_load(self.config["indexes_dict_path"]) + # E.g., { + # 'vocals': [ + # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}, + # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410}, + # ... (e.g., 225752 dicts) + # ], + # 'accompaniment': [ + # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}, + # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410}, + # ... (e.g., 225752 dicts) + # ] + # } + + self.pointers_dict = {source_type: 0 for source_type in self.source_list} + # E.g., {'vocals': 0, 'accompaniment': 0} + + self.indexes_dict = { + source_type: np.arange(len(self.meta_dict[source_type])) + for source_type in self.source_list + } + # E.g. { + # 'vocals': [0, 1, ..., 225751], + # 'accompaniment': [0, 1, ..., 225751] + # } + + random_state_for_source_random_seed = np.random.RandomState(self.config["random_seed"]) + self.random_state_dict = {} + + for source_type in self.source_list: + + if self.mix_data_augmentation: + # Use different seeds for different sources. + source_random_seed = random_state_for_source_random_seed.randint(low=0, high=10000) + + else: + # Use same seeds for different sources. 
+ source_random_seed = self.config["random_seed"] + + self.random_state_dict[source_type] = np.random.RandomState( + source_random_seed + ) + + self.random_state_dict[source_type].shuffle(self.indexes_dict[source_type]) + # E.g., [198036, 196736, ..., 103408] + + print("{}: {}".format(source_type, len(self.indexes_dict[source_type]))) + + def __iter__(self) -> List[Dict]: + r"""Yield a batch of meta info. + Returns: + batch_meta_list: (batch_size,) e.g., when mix-audio is 2, looks like [ + {'vocals': [ + {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700}, + {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}] + 'accompaniment': [ + {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760}, + {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}] + }, + ... + ] + """ + batch_size = self.batch_size + + while True: + batch_meta_dict = {source_type: [] for source_type in self.source_list} + + for source_type in self.source_list: + # E.g., ['vocals', 'accompaniment'] + + # Loop until get a mini-batch. + while len(batch_meta_dict[source_type]) != batch_size: + + if source_type in self.mix_data_augmentation_num.keys(): + mix_audios_num = self.mix_data_augmentation_num[source_type] + + else: + mix_audios_num = 1 + + largest_index = len(self.indexes_dict[source_type]) - mix_audios_num + # E.g., 225750 = 225752 - 2 + + if self.pointers_dict[source_type] > largest_index: + + # Reset pointer, and shuffle indexes. + self.pointers_dict[source_type] = 0 + self.random_state_dict[source_type].shuffle( + self.indexes_dict[source_type] + ) + + source_metas = [] + + for _ in range(mix_audios_num): + + pointer = self.pointers_dict[source_type] + # E.g., 1 + + index = self.indexes_dict[source_type][pointer] + # E.g., 12231 + + self.pointers_dict[source_type] += 1 + + source_meta = self.meta_dict[source_type][index] + # E.g., { + # 'hdf5_path': 'xx/song_A.h5', + # 'key_in_hdf5': 'vocals', + # 'begin_sample': 13406400, + # } + + source_metas.append(source_meta) + + batch_meta_dict[source_type].append(source_metas) + + # When mix-audio is 2, batch_meta_dict looks like: { + # 'vocals': [ + # [{'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700}, + # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170} + # ], + # ... (batch_size) + # ] + # 'accompaniment': [ + # [{'hdf5_path': 'songG.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 24232950, 'end_sample': 24365250}, + # {'hdf5_path': 'songH.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 1569960, 'end_sample': 1702260} + # ], + # ... (batch_size) + # ] + # } + + batch_meta_list = [ + { + source_type: batch_meta_dict[source_type][i] + for source_type in self.source_list + } + for i in range(batch_size) + ] + # When mix-audio is 2, batch_meta_list looks like: [ + # {'vocals': [ + # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700}, + # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}] + # 'accompaniment': [ + # {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 14579460, 'end_sample': 14711760}, + # {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 3995460, 'end_sample': 4127760}] + # } + # ... 
(batch_size) + # ] + + yield batch_meta_list + + def __len__(self) -> int: + return self.steps_per_epoch + + def state_dict(self) -> Dict: + state = {'pointers_dict': self.pointers_dict, 'indexes_dict': self.indexes_dict} + return state + + def load_state_dict(self, state) -> NoReturn: + self.pointers_dict = state['pointers_dict'] + self.indexes_dict = state['indexes_dict'] \ No newline at end of file diff --git a/TorchJaekwon/Data/PytorchDataLoader/PytorchDataLoader.py b/TorchJaekwon/Data/PytorchDataLoader/PytorchDataLoader.py new file mode 100644 index 0000000000000000000000000000000000000000..477da9a01335b3c978486335703d80a8c3e2c958 --- /dev/null +++ b/TorchJaekwon/Data/PytorchDataLoader/PytorchDataLoader.py @@ -0,0 +1,61 @@ +from typing import Dict +from torch.utils.data import Dataset, DataLoader + +from HParams import HParams +from TorchJaekwon.GetModule import GetModule + +class PytorchDataLoader: + def __init__(self): + self.h_params = HParams() + self.data_loader_config:dict = self.h_params.pytorch_data.dataloader + + def get_pytorch_data_loaders(self) -> Dict[str,DataLoader]: #subset,dataloader + pytorch_dataset_dict:Dict[str,Dataset] = self.get_pytorch_data_set_dict() #key: subset, value: dataset + pytorch_data_loader_config_dict:dict = self.get_pytorch_data_loader_args(pytorch_dataset_dict) + pytorch_data_loader_dict:Dict[str,DataLoader] = self.get_pytorch_data_loaders_from_config(pytorch_data_loader_config_dict) + return pytorch_data_loader_dict + + def get_pytorch_data_set_dict(self) -> Dict[str,Dataset]: + pytorch_dataset_dict:Dict[str,Dataset] = dict() + for subset in self.data_loader_config: + dataset_args:dict = self.data_loader_config[subset]["dataset"]['class_meta']['args'] + pytorch_dataset_dict[subset] = GetModule.get_module_class('./Data/PytorchDataset',self.data_loader_config[subset]["dataset"]['class_meta']["name"])(**dataset_args) + return pytorch_dataset_dict + + def get_pytorch_data_loader_args(self,pytorch_dataset:dict) -> dict: + pytorch_data_loader_config_dict:dict = {subset:dict() for subset in pytorch_dataset} + + for subset in pytorch_dataset: + args_exception_list = self.get_exception_list_of_dataloader_parameters(subset) + pytorch_data_loader_config_dict[subset]["dataset"] = pytorch_dataset[subset] + for arg_name in self.data_loader_config[subset]: + if arg_name in args_exception_list: + continue + if arg_name == 'batch_sampler': + arguments_for_args_class:dict = self.h_params.pytorch_data.dataloader[subset]['batch_sampler'] + arguments_for_args_class.update({"pytorch_dataset":pytorch_dataset[subset],"subset":subset}) + pytorch_data_loader_config_dict[subset][arg_name] = GetModule.get_module_class('./Data/PytorchDataLoader', + self.data_loader_config[subset][arg_name]["class_name"] + )(arguments_for_args_class) + elif arg_name == 'collate_fn': + if self.data_loader_config[subset][arg_name] == True: pytorch_data_loader_config_dict[subset][arg_name] = pytorch_data_loader_config_dict[subset]["dataset"].collate_fn + else: + pytorch_data_loader_config_dict[subset][arg_name] = self.data_loader_config[subset][arg_name] + + return pytorch_data_loader_config_dict + + def get_exception_list_of_dataloader_parameters(self,subset): + args_exception_list = ["dataset"] + if "batch_sampler" in self.data_loader_config[subset]: + args_exception_list += ["batch_size", "shuffle", "sampler", "drop_last"] + return args_exception_list + + def get_pytorch_data_loaders_from_config(self,dataloader_config:dict) -> dict: + pytorch_data_loader_dict = dict() + for subset in 
dataloader_config: + pytorch_data_loader_dict[subset] = DataLoader(**dataloader_config[subset]) + return pytorch_data_loader_dict + + + + \ No newline at end of file diff --git a/TorchJaekwon/Data/PytorchDataset/DataSet.py b/TorchJaekwon/Data/PytorchDataset/DataSet.py new file mode 100644 index 0000000000000000000000000000000000000000..472c93bba6b7ee28ff958bc6137dff6f7d7ff70f --- /dev/null +++ b/TorchJaekwon/Data/PytorchDataset/DataSet.py @@ -0,0 +1,22 @@ +import torch.utils.data.dataset as dataset +import pickle + +class DataSet(dataset.Dataset): + + def __init__(self, config: dict): + data_path_list = config["data_path_list"] + self.data_set_type = config["subset"] + self.files = [] + for fname in data_path_list: + self.files.append(self.read_data(fname)) + + def read_data(self, data_path): + with open(data_path, 'rb') as pickle_file: + file_data_dict = pickle.load(pickle_file) + return file_data_dict + + def __len__(self): + return len(self.files) + + def __getitem__(self, index): + return self.files[index] \ No newline at end of file diff --git a/TorchJaekwon/Data/PytorchDataset/EvenSampleFromMultipleDataset.py b/TorchJaekwon/Data/PytorchDataset/EvenSampleFromMultipleDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d4f46bb3c42398cd3bbca72e6d3a2c3c23114e42 --- /dev/null +++ b/TorchJaekwon/Data/PytorchDataset/EvenSampleFromMultipleDataset.py @@ -0,0 +1,58 @@ +from typing import Union, Dict +from numpy import ndarray + +import time +import random +import numpy as np +import torch +from torch.utils.data import IterableDataset +from torch.utils.data.dataset import Dataset + +class EvenSampleFromMultipleDataset(IterableDataset): + def __init__(self, + is_multiple_random_seed:bool = True, + random_seed:int = (int)(torch.cuda.initial_seed() / (2**32)) + ) -> None: + self.data_list_dict: Dict[str,list] = self.init_data_list_dict() # {data_type1: List, data_type2: List} + self.data_set_class_list = list(self.data_list_dict.keys()) + + self.idx_dict = {data_name: 0 for data_name in self.data_list_dict} + self.idx_dict['data_class'] = 0 + + self.random_state_dict = dict() + self.random_state_dict['data_class'] = np.random.RandomState(random_seed) + self.random_state_dict['data_class'].shuffle(self.data_set_class_list) + + for data_name in self.data_list_dict: + random_seed_for_data = np.random.RandomState(random_seed).randint(low=0, high=10000) if is_multiple_random_seed else random_seed + self.random_state_dict[data_name] = np.random.RandomState(random_seed_for_data) + self.random_state_dict[data_name].shuffle(self.data_list_dict[data_name]) + print("{}: {}".format(data_name, len(self.data_list_dict[data_name]))) + + def init_data_list_dict(self) -> Dict[str,list]: # {data_type1: List, data_type2: List} + pass + + def read_data(self,meta_data): + pass + + def __iter__(self): + while True: + self.idx_dict['data_class'] = self.idx_dict['data_class'] + 1 + if self.idx_dict['data_class'] == len(self.data_set_class_list): + self.idx_dict['data_class'] = 0 + self.random_state_dict['data_class'].shuffle(self.data_set_class_list) + data_class:str = self.data_set_class_list[self.idx_dict['data_class']] + + self.idx_dict[data_class] = self.idx_dict[data_class] + 1 + if self.idx_dict[data_class] == len(self.data_list_dict[data_class]): + self.idx_dict[data_class] = 0 + self.random_state_dict[data_class].shuffle(self.data_list_dict[data_class]) + + data = self.read_data(self.data_list_dict[data_class][self.idx_dict[data_class]]) + + yield data + + def __len__(self): + return 
max([len(self.data_list_dict[data_name]) for data_name in self.data_list_dict]) + + diff --git a/TorchJaekwon/Data/PytorchDataset/IterablePytorchDataset.py b/TorchJaekwon/Data/PytorchDataset/IterablePytorchDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0e978b349198a8f1bab2bf0d49034b08dbb61233 --- /dev/null +++ b/TorchJaekwon/Data/PytorchDataset/IterablePytorchDataset.py @@ -0,0 +1,67 @@ +from typing import Union, Dict +from numpy import ndarray + +import time +import random +import numpy as np +import torch +from torch.utils.data import IterableDataset + +class IterablePytorchDataset(IterableDataset): + def __init__(self, + data_type:Union[str,list] = None, + is_multiple_random_seed:bool = True, + random_seed:int = (int)(torch.cuda.initial_seed() / (2**32)) + ) -> None: + + self.is_multiple_data_type:bool = isinstance(data_type,list) + + if self.is_multiple_data_type: + self.data_type_list:list = data_type + self.data_dict:Dict[str,ndarray] = self.init_data_dict() + self.indexes_dict = { data_name: np.arange(len(self.data_dict[data_name])) for data_name in data_type } + self.pointers_dict = {data_name: 0 for data_name in data_type} + + self.random_state_dict = {} + for data_name in data_type: + random_seed_for_data = np.random.RandomState(random_seed).randint(low=0, high=10000) if is_multiple_random_seed else random_seed + self.random_state_dict[data_name] = np.random.RandomState(random_seed_for_data) + self.random_state_dict[data_name].shuffle(self.indexes_dict[data_name]) + print("{}: {}".format(data_name, len(self.indexes_dict[data_name]))) + else: + print('make data list') + start = time.time() + self.data_list = self.init_data_list() + print(f"make data list: took {time.time() - start:.5f} sec") + self.index:int = 0 + random.shuffle(self.data_list) + + def init_data_list(self) -> list: + pass + + def init_data_dict(self) -> Dict[str,ndarray]: # {data_type1: List, data_type2: List} + pass + + def __iter__(self): + while True: + data = self.get_data(self) + yield data + + def get_data(self): + if self.is_multiple_data_type: + data_dict = dict() + for data_name in self.data_type_list: + if self.pointers_dict[data_name] >= len(self.indexes_dict[data_name]): + self.pointers_dict[data_name] = 0 + self.random_state_dict[data_name].shuffle( self.indexes_dict[data_name] ) + data_dict[data_name] = self.read_data(self.data_dict[self.indexes_dict[data_name][self.pointers_dict[data_name]]]) + return data_dict + else: + self.index = self.index + 1 + if self.index == len(self.data_list): + self.index = 0 + random.shuffle(self.data_list) + return self.read_data(self.data_list[self.index]) + + def read_data(self,meta_data): + pass diff --git a/TorchJaekwon/DataProcess/Preprocess/MakeMetaDataScale.py b/TorchJaekwon/DataProcess/Preprocess/MakeMetaDataScale.py new file mode 100644 index 0000000000000000000000000000000000000000..fcad830fb3de86a224e44ec2c304da92ef3a40ec --- /dev/null +++ b/TorchJaekwon/DataProcess/Preprocess/MakeMetaDataScale.py @@ -0,0 +1,84 @@ +import os +from DataProcess.MakeMetaData.MakeMetaData import MakeMetaData +import sklearn.preprocessing +import copy +import tqdm +import numpy as np +from HParams import HParams +from GetModule import GetModule +from Data.PytorchDataLoader.PytorchDataLoader import PytorchDataLoader +from DataProcess.Process.Process import Process + +class MakeMetaDataScale(MakeMetaData): + + def __init__(self,h_params:HParams, make_meta_data_config:dict) -> None: + super().__init__(h_params,make_meta_data_config) + + 
self.get_module = GetModule() + self.data_loader_loader:PytorchDataLoader = self.get_module.get_module('pytorch_dataLoader',self.h_params.pytorch_data.name,self.h_params) + + if self.h_params.process.name is not None: + self.data_processor:Process = self.get_module.get_module("process",self.h_params.process.name, {"h_params":self.h_params},arg_unpack=True) + else: + self.data_processor = None + + def make_meta_data(self): + train_data_path = self.data_loader_loader.get_data_path_dict()["train"] + pytorch_data_set = self.get_pytorch_data_set(train_data_path=train_data_path) + result = self.get_statistics(pytorch_data_set) + print("end") + + def get_pytorch_data_set(self,train_data_path) -> dict: + dataset_config = self.h_params.make_meta_data.make_meta_data_dict["MakeMetaDataScale"]["dataset"] + config_for_dataset = { + "h_params": self.h_params, + "data_path_list": train_data_path, + "subset": "train", + "data_set_config": dataset_config + } + pytorch_dataset = self.get_module.get_module("pytorch_dataset", dataset_config["name"],config_for_dataset) + return pytorch_dataset + + def get_statistics(self, dataset): + standard_scaler = dict() + minmax_scaler = dict() + + pbar = tqdm.tqdm(range(len(dataset)), disable=False) + for ind in pbar: + feature_dict = dataset[ind] + for feature_name in feature_dict: + feature_dict[feature_name] = feature_dict[feature_name].astype(np.float32) + pbar.set_description("Compute dataset statistics") + + if self.data_processor is not None: + dataset_config = self.h_params.pytorch_data.dataloader["train"]['dataset'] + train_data_name_dict = dict() + train_data_name_dict["input_name"] = dataset_config["train_source_name_dict"]["input"] + train_data_name_dict["target_name"] = dataset_config["train_source_name_dict"]["target"] + train_data_dict = self.data_processor.preprocess_training_data(feature_dict,additional_dict=train_data_name_dict) + else: + train_data_dict = feature_dict + + for feature_name in train_data_dict: + if feature_name not in standard_scaler: + standard_scaler[feature_name] = sklearn.preprocessing.StandardScaler() + minmax_scaler[feature_name] = sklearn.preprocessing.MinMaxScaler() + + train_data_dict[feature_name] = train_data_dict[feature_name].squeeze() + train_data_dict[feature_name] = np.transpose(train_data_dict[feature_name],(0,2,1)) + train_data_dict[feature_name] = train_data_dict[feature_name].reshape(-1,train_data_dict[feature_name].shape[-1]) + + standard_scaler[feature_name].partial_fit(train_data_dict[feature_name]) + minmax_scaler[feature_name].partial_fit(train_data_dict[feature_name]) + + result = dict() + for feature_name in standard_scaler: + result[feature_name] = dict() + result[feature_name]["mean"] = standard_scaler[feature_name].mean_ + result[feature_name]["std"] = np.maximum(standard_scaler[feature_name].scale_, 1e-4 * np.max(standard_scaler[feature_name].scale_)) + result[feature_name]["max_by_bin"] = minmax_scaler[feature_name].data_max_ + result[feature_name]["min_by_bin"] = minmax_scaler[feature_name].data_min_ + result[feature_name]["max"] = np.max(result[feature_name]["max_by_bin"]) + result[feature_name]["min"] = np.min(result[feature_name]["min_by_bin"]) + + return result \ No newline at end of file diff --git a/TorchJaekwon/DataProcess/Preprocess/MakeMetaDataSegmentIndexByFeatureType.py b/TorchJaekwon/DataProcess/Preprocess/MakeMetaDataSegmentIndexByFeatureType.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc196b213e3a5d7698788bd96d8897409942811 --- /dev/null +++ 
b/TorchJaekwon/DataProcess/Preprocess/MakeMetaDataSegmentIndexByFeatureType.py @@ -0,0 +1,75 @@ +import os +import pickle + +from HParams import HParams +from DataProcess.MakeMetaData.MakeMetaData import MakeMetaData + +class MakeMetaDataSegmentIndexByFeatureType(MakeMetaData): + r"""Create and write out training indexes into disk. The indexes may contain + information from multiple datasets. During training, training indexes will + be shuffled and iterated for selecting segments to be mixed. E.g., the + training indexes_dict looks like: { + 'audio_vocals': [ + {'name':..} + ... + ] + 'audio_accompaniment': [ + {'name':..} + ... + ] + } + """ + + def __init__(self, h_params:HParams, make_meta_data_config:dict) -> None: + super().__init__(h_params,make_meta_data_config) + config:dict = make_meta_data_config + + self.feature_list = config["feature_list"] + self.config_of_subset_dict:dict = config["config_of_subset"] + + self.sample_rate:int = self.h_params.preprocess.sample_rate + self.file_ext = ".pkl" + + self.result_file_name = config["result_file_name"] + + def make_meta_data(self): + for subset in self.config_of_subset_dict: + segment_index_dict:dict = {feature_type: [] for feature_type in self.feature_list} + segment_samples_length = int(self.h_params.make_meta_data.segment_seconds * self.sample_rate) + + if "hopsize" in self.config_of_subset_dict[subset]: + segment_samples_hop_size:int = self.config_of_subset_dict[subset]["hopsize"] + else: + segment_samples_hop_size = int(self.config_of_subset_dict[subset]["hop_seconds"] * self.sample_rate) + + for feature_type in segment_index_dict: + print("--- {} ---".format(feature_type)) + segment_data_num = 0 + + for data_root_path in self.data_root_path_list: + data_path = os.path.join(data_root_path,subset) + data_name_list = sorted(os.listdir(data_path)) + + for i, data_name in enumerate(data_name_list): + print(f"{feature_type} of {data_name} ({i+1} / {len(data_name_list)})") + file_path = os.path.join(os.path.join(data_path,data_name),feature_type) + self.file_ext + with open(file_path, 'rb') as pickle_file: + feature = pickle.load(pickle_file) + + begin_sample = 0 + while begin_sample + segment_samples_length < feature.shape[-1]: + segment_index_dict[feature_type].append( + { + "name": data_name, + "data_path": file_path, + "feature_type": feature_type, + "begin_sample":begin_sample, + "end_sample": begin_sample + segment_samples_length + }) + begin_sample += segment_samples_hop_size + segment_data_num += 1 + print("{} indexes: {}".format(data_root_path, segment_data_num)) + print( "Total indexes for {}: {}".format(feature_type, len(segment_index_dict[feature_type]))) + + pickle.dump(segment_index_dict, open(os.path.join(self.h_params.data.root_path,self.result_file_name), "wb")) + print("Write index dict to {}".format(os.path.join(self.h_params.data.root_path,self.result_file_name))) \ No newline at end of file diff --git a/TorchJaekwon/DataProcess/Preprocess/Preprocessor.py b/TorchJaekwon/DataProcess/Preprocess/Preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..7c084e7867299f1dd46861b68e1f2443c8e47284 --- /dev/null +++ b/TorchJaekwon/DataProcess/Preprocess/Preprocessor.py @@ -0,0 +1,66 @@ +from typing import List + +from abc import ABC, abstractmethod +from concurrent.futures import ProcessPoolExecutor +import os +import time +import torch +from tqdm import tqdm + +class Preprocessor(ABC): + def __init__(self, + data_name:str = None, + root_dir:str = None, + device:torch.device = None, + num_workers:int 
= 1, + ) -> None: + # args to class variable + self.data_name:str = data_name + self.root_dir:str = root_dir + self.num_workers:int = num_workers + self.device:torch.device = device + if self.root_dir is not None and self.data_name is not None: + self.output_dir = self.get_output_dir() + os.makedirs(self.output_dir,exist_ok=True) + else: + print('Warning: root_dir or data_name is None') + + def get_output_dir(self) -> str: + return os.path.join(self.root_dir, self.data_name) + + def write_message(self,message_type:str,message:str) -> None: + with open(f"{self.preprocessed_data_path}/{message_type}.txt",'a') as file_writer: + file_writer.write(message+'\n') + + def preprocess_data(self) -> None: + meta_param_list:list = self.get_meta_data_param() + if meta_param_list is None: + print('meta_param_list is None, So we skip preprocess data') + return + start_time:float = time.time() + if self.num_workers > 2: + with ProcessPoolExecutor(max_workers=self.num_workers) as pool: + pool.map(self.preprocess_one_data, meta_param_list) + else: + for meta_param in tqdm(meta_param_list,desc='preprocess data'): + self.preprocess_one_data(meta_param) + + self.final_process() + print("{:.3f} s".format(time.time() - start_time)) + + @abstractmethod + def get_meta_data_param(self) -> list: + ''' + meta_data_param_list = list() + ''' + raise NotImplementedError + + @abstractmethod + def preprocess_one_data(self,param: tuple) -> None: + ''' + ex) (subset, file_name) = param + ''' + raise NotImplementedError + + def final_process(self) -> None: + print("Finish preprocess") \ No newline at end of file diff --git a/TorchJaekwon/Evaluater/Evaluater.py b/TorchJaekwon/Evaluater/Evaluater.py new file mode 100644 index 0000000000000000000000000000000000000000..a01387f20c44125110228f8c9232075b3f18d066 --- /dev/null +++ b/TorchJaekwon/Evaluater/Evaluater.py @@ -0,0 +1,87 @@ +#type +from typing import List,Dict,Union +#import +import os +from tqdm import tqdm +import numpy as np +import torch +#torchjaekwon import +from TorchJaekwon.Util.UtilData import UtilData +#internal import + +class Evaluater(): + def __init__(self, + source_dir:str, + reference_dir:str = None, + sort_result_by_metric:bool = True, + device:torch.device = torch.device('cpu') + ) -> None: + self.source_dir:str = source_dir + self.reference_dir:str = reference_dir + self.sort_result_by_metric = sort_result_by_metric + self.device:torch.device = device + + ''' + ============================================================== + abstract method start + ============================================================== + ''' + def get_eval_dir_list(self) -> List[str]: + return [self.source_dir] + + def get_meta_data_list(self, eval_dir:str) -> List[dict]: + pass + + def get_result_dict_for_one_testcase( + self, + meta_data:dict + ) -> dict: #{'name':name_of_testcase,'metric_name1':value1,'metric_name2':value2... 
} + pass + ''' + ============================================================== + abstract method end + ============================================================== + ''' + def evaluate(self) -> None: + eval_dir_list:List[str] = self.get_eval_dir_list() + + evaluation_result_dir:str = f"{self.source_dir}/_evaluation" + os.makedirs(evaluation_result_dir,exist_ok=True) + + for eval_dir in tqdm(eval_dir_list, desc='evaluate eval dir'): + meta_data_list: List[dict] = self.get_meta_data_list(eval_dir) + result_dict:dict = self.get_result_dict(meta_data_list) + result_dict['statistic'].update(self.set_eval(eval_dir=eval_dir)) + + test_set_name:str = eval_dir.split('/')[-1] + UtilData.yaml_save(f'{evaluation_result_dir}/{test_set_name}_mean_median_std.yaml',result_dict['statistic']) + if self.sort_result_by_metric: + for metric_name in result_dict['metric_name_list']: + UtilData.yaml_save(f'{evaluation_result_dir}/{test_set_name}_sort_by_{metric_name}.yaml',UtilData.sort_dict_list( dict_list = result_dict['result'], key = metric_name)) + + def get_result_dict(self,meta_data_list:List[dict]) -> dict: + result_dict_list:List[dict] = list() + for meta_data in tqdm(meta_data_list,desc='get result'): + result_dict_list.append(self.get_result_dict_for_one_testcase(meta_data)) + + metric_name_list:list = [metric_name for metric_name in list(result_dict_list[0].keys()) if type(result_dict_list[0][metric_name]) in [float,np.float_]] + metric_name_list.sort() + mean_median_std_dict:dict = self.get_mean_median_std_from_dict_list(result_dict_list,metric_name_list) + + return {'metric_name_list': metric_name_list, 'result':result_dict_list, 'statistic':mean_median_std_dict} + + def get_mean_median_std_from_dict_list(self,dict_list:List[dict],metric_name_list:List[str]): + result_list_dict:dict = {metric_name: list() for metric_name in metric_name_list} + for result in dict_list: + for metric_name in metric_name_list: + result_list_dict[metric_name].append(result[metric_name]) + result_dict = dict() + for metric_name in metric_name_list: + result_dict[metric_name] = dict() + result_dict[metric_name]['mean'] = float(np.mean(result_list_dict[metric_name])) + result_dict[metric_name]['median'] = float(np.median(result_list_dict[metric_name])) + result_dict[metric_name]['std'] = float(np.std(result_list_dict[metric_name])) + return result_dict + + def set_eval(self, eval_dir:str) -> dict: #key: metric_name + return dict() \ No newline at end of file diff --git a/TorchJaekwon/Evaluater/EvaluaterSVS.py b/TorchJaekwon/Evaluater/EvaluaterSVS.py new file mode 100644 index 0000000000000000000000000000000000000000..bb78c23ea8429f3e46e9912e8239dd4ca6fd8820 --- /dev/null +++ b/TorchJaekwon/Evaluater/EvaluaterSVS.py @@ -0,0 +1,129 @@ +from HParams import HParams +from Evaluater.Evaluater import Evaluater +import soundfile as sf +from scipy.io import wavfile +import numpy as np +import museval +import os +import librosa +from scipy.spatial.distance import euclidean +import pysptk +from fastdtw import fastdtw +import Evaluater.MetricVoice as vm +from Evaluater.MetricVoice import MetricVoice + +class EvaluaterSVS(Evaluater): + def __init__(self, h_params: HParams): + super().__init__(h_params) + self.pred_gt_name_dict = h_params.evaluate.pred_gt_dict + self.voice_metric = MetricVoice(self.h_params) + + def read_pred_gt_list(self,data_name,read_module_name="librosa"): + ''' + "key": Pred", gt: + ["key",.."key"...] 
+ ["key":sdr] + return {"pred": list, "gt" : list} + ''' + file_path = f"{self.data_path}/{data_name}" + + pred_gt_dict = dict() + + for data_name in self.pred_gt_name_dict: + pred_gt_dict[data_name] = dict() + if read_module_name == "soundfile": + pred_gt_dict[data_name]["gt"],sr = sf.read(f"{file_path}/{self.pred_gt_name_dict[data_name]['gt_audio_file_name']}") + pred_gt_dict[data_name]["pred"],sr = sf.read(f"{file_path}/{self.pred_gt_name_dict[data_name]['pred_audio_file_name']}") + elif read_module_name == "librosa": + pred_gt_dict[data_name]["gt"],sr = librosa.load(f"{file_path}/{self.pred_gt_name_dict[data_name]['gt_audio_file_name']}",sr=None) + pred_gt_dict[data_name]["pred"],sr = librosa.load(f"{file_path}/{self.pred_gt_name_dict[data_name]['pred_audio_file_name']}",sr=None) + assert pred_gt_dict[data_name]["gt"].shape == pred_gt_dict[data_name]["pred"].shape, "pred shape and gt shape shoud be same" + return pred_gt_dict + + def evaluator(self,test_set_dict) -> dict: + ''' + references : np.ndarray, shape=(nsrc, nsampl, nchan) + array containing true reference sources + estimates : np.ndarray, shape=(nsrc, nsampl, nchan) + array containing estimated sources + + return evaluation resutl + ''' + final_evaluation_dict = dict() + + data_name_list = [] + references_list = [] #gt + estimates_list = [] #pred + for data_name in test_set_dict: + final_evaluation_dict[data_name] = dict() + data_name_list.append(data_name) + references_list.append(test_set_dict[data_name]['gt']) + estimates_list.append(test_set_dict[data_name]['pred']) + if 'voice' in data_name: + print("get_mcd") + + mcd, length =vm.get_mcd( source=test_set_dict[data_name]['pred'], + target=test_set_dict[data_name]['gt'], + sample_rate=self.h_params.preprocess.sample_rate) + final_evaluation_dict[data_name]["MCD"] = mcd + sispnr = self.voice_metric.get_sispnr(pred_audio=test_set_dict[data_name]['pred'],target_audio=test_set_dict[data_name]['gt']) + final_evaluation_dict[data_name].update(sispnr) + sdr_torchmetrics:dict = self.voice_metric.get_sdr_torchmetrics(pred_audio=test_set_dict[data_name]['pred'],target_audio=test_set_dict[data_name]['gt']) + final_evaluation_dict[data_name].update(sdr_torchmetrics) + + print("get SDR, ISR, SIR, SAR") + SDR, ISR, SIR, SAR = museval.evaluate(references=references_list,estimates=estimates_list) + evaluation_dict = {"SDR":SDR,"ISR":ISR,"SIR":SIR,"SAR":SAR} + + for i,data_name in enumerate(data_name_list): + for metric in evaluation_dict: + evaluation_metric = evaluation_dict[metric][i] + evaluation_metric = evaluation_metric[np.isfinite(evaluation_metric)] + final_evaluation_dict[data_name][metric] = np.median(evaluation_metric) + return final_evaluation_dict + + + def mcd_eval(self,ref_vocal, estimate_vocal): + _logdb_const = 10.0 / np.log(10.0) * np.sqrt(2.0) + + mgc1 = self.readmgc(ref_vocal) + mgc2 = self.readmgc(estimate_vocal) + + x = mgc1 + y = mgc2 + + distance, path = fastdtw(x, y, dist=euclidean) + + distance/= (len(x) + len(y)) + pathx = list(map(lambda l: l[0], path)) + pathy = list(map(lambda l: l[1], path)) + x, y = x[pathx], y[pathy] + + frames = x.shape[0] + + z = x - y + s = np.sqrt((z * z).sum(-1)).sum() + + return (_logdb_const * float(s) / float(frames)) + + def readmgc(self, audio_data): + print("readmgc") + x = self.util.change_audio_data_type_float_to_int(audio_data) + if x.ndim == 2: + x = x[:,1] + frame_length = 1024 + hop_length = 256 + # Windowing + print("windowing") + frames = librosa.util.frame(x, frame_length=frame_length, 
hop_length=hop_length).astype(np.float64).T + frames *= pysptk.blackman(frame_length) + assert frames.shape[1] == frame_length + # Order of mel-cepstrum + order = 25 + alpha = 0.41 + stage = 5 + gamma = -1.0 / stage + print("get mgcep") + mgc = pysptk.mgcep(frames, order, alpha, gamma) + mgc = mgc.reshape(-1, order + 1) + return mgc \ No newline at end of file diff --git a/TorchJaekwon/Evaluater/MetricSound.py b/TorchJaekwon/Evaluater/MetricSound.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9653398284434646ddf979ea44ee87ab66cd62 --- /dev/null +++ b/TorchJaekwon/Evaluater/MetricSound.py @@ -0,0 +1,90 @@ +from numpy import ndarray +from torch import Tensor + +import numpy as np +import librosa +import torch +import torchaudio +from TorchJaekwon.Evaluater.Package.pysepm.qualityMeasures import fwSNRseg + +from TorchJaekwon.Util.UtilData import UtilData +from TorchJaekwon.Util.UtilAudio import UtilAudio + +class MetricSound: + @staticmethod + def fwssnr(target_wave:Tensor, #[batch, channel, signal_length] + pred_wave:Tensor, #[batch, channel, signal_length] + batch_size:int = 1, + sample_rate:int = 16000): + ''' + the frequency-weighted segmental SNR (fwSSNR) + Y. Hu and Philipos C. Loizou, “Evaluation of objective quality measures for speech enhancement,” IEEE TASLP, 2008. + ''' + target_wave = UtilData.fit_shape_length(target_wave, 3) + pred_wave = UtilData.fit_shape_length(pred_wave, 3) + target_wave, pred_wave = MetricSound.trim_samples(target_wave, pred_wave) + target_wave = UtilAudio.normalize_by_fro_norm(target_wave) + pred_wave = UtilAudio.normalize_by_fro_norm(pred_wave) + + fwSNRseg_vectorized = np.vectorize(fwSNRseg, signature='(n),(n),()->()') + values = [] + for i in range(0, target_wave.shape[0], batch_size): + target_batch = target_wave[i:i+batch_size, 0, :].detach().cpu().numpy() + pred_batch = pred_wave[i:i+batch_size, 0, :].detach().cpu().numpy() + batch_values = fwSNRseg_vectorized(target_batch, pred_batch, fs = sample_rate) + values.extend(batch_values.tolist()) + return float(np.mean(values)) + + def multi_resolution_spectrogram_distance(target_wave:Tensor, #[batch, channel, signal_length] + pred_wave:Tensor, #[batch, channel, signal_length] + batch_size:int = 1): + target_wave = UtilData.fit_shape_length(target_wave, 3) + pred_wave = UtilData.fit_shape_length(pred_wave, 3) + target_wave, pred_wave = MetricSound.trim_samples(target_wave, pred_wave) + target_wave = UtilAudio.normalize_by_fro_norm(target_wave) + pred_wave = UtilAudio.normalize_by_fro_norm(pred_wave) + + n_ffts = [512, 1024, 2048] + hop_lengths = [128, 256, 512] + spec_transforms = [torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=1) for n_fft, hop_length in zip(n_ffts, hop_lengths)] + values = [] + for i in range(0, target_wave.shape[0], batch_size): + clean_batch = target_wave[i:i+batch_size] + estimated_batch = pred_wave[i:i+batch_size] + clean_specs = [(spec_transform(clean_batch) + 1e-10).reshape(clean_batch.shape[0], -1) for spec_transform in spec_transforms] + estimated_specs = [(spec_transform(estimated_batch) + 1e-10).reshape(estimated_batch.shape[0], -1) for spec_transform in spec_transforms] + + losses_sc = [(torch.square(clean_spec - estimated_spec).sum(1) / torch.square(clean_spec).sum(1)) for clean_spec, estimated_spec in zip(clean_specs, estimated_specs)] + losses_mag = [(clean_spec.log() - estimated_spec.log()).abs().mean(dim=1) for clean_spec, estimated_spec in zip(clean_specs, estimated_specs)] + losses = [(loss_sc + 
loss_mag).detach().cpu().numpy() for loss_sc, loss_mag in zip(losses_sc, losses_mag)] + loss_batch = np.mean(losses, axis=0) + values.extend(loss_batch.tolist()) + return float(np.mean(values)) + + def spectrogram_l1(target_wave:Tensor, #[batch, channel, signal_length] + pred_wave:Tensor, #[batch, channel, signal_length] + batch_size:int = 1): + target_wave = UtilData.fit_shape_length(target_wave, 3) + pred_wave = UtilData.fit_shape_length(pred_wave, 3) + target_wave, pred_wave = MetricSound.trim_samples(target_wave, pred_wave) + target_wave = UtilAudio.normalize_by_fro_norm(target_wave) + pred_wave = UtilAudio.normalize_by_fro_norm(pred_wave) + + melspec_transform = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256) + values = [] + for i in range(0, target_wave.shape[0], batch_size): + clean_batch = target_wave[i:i+batch_size] + estimated_batch = pred_wave[i:i+batch_size] + clean_spec = (melspec_transform(clean_batch) + 1e-10).log() + estimated_spec = (melspec_transform(estimated_batch) + 1e-10).log() + clean_spec, estimated_spec = clean_spec.reshape(clean_spec.shape[0], -1), estimated_spec.reshape(estimated_spec.shape[0], -1) + values.extend((clean_spec - estimated_spec).abs().mean(dim=1).detach().cpu().numpy().tolist()) + return float(np.mean(values)) + + def signal_to_noise(self,pred_waveform:ndarray,target_waveform:ndarray) -> float: + return 10.*np.log10(np.sqrt(np.sum(target_waveform**2))/np.sqrt(np.sum((target_waveform - pred_waveform)**2))) + + @staticmethod + def trim_samples(sample1, sample2): + min_length = min(sample1.shape[2], sample2.shape[2]) + return sample1[:,:,:min_length], sample2[:,:,:min_length] \ No newline at end of file diff --git a/TorchJaekwon/Evaluater/MetricVoice.py b/TorchJaekwon/Evaluater/MetricVoice.py new file mode 100644 index 0000000000000000000000000000000000000000..842de1dc6335afd8f909c5958fab8adc96efc5b0 --- /dev/null +++ b/TorchJaekwon/Evaluater/MetricVoice.py @@ -0,0 +1,250 @@ +#type +from typing import Optional,Dict +from torch import Tensor +from numpy import ndarray +#import +import numpy as np +import torch +try: import pyworld as pw +except: print('Warning: pyworld is not installed') +try: import pysptk +except: print('Warning: pysptk is not installed') +try: from pesq import pesq +except: print('Warning: pesq is not installed') +try: from fastdtw import fastdtw +except: print('Warning: fastdtw is not installed') +try: from skimage.metrics import structural_similarity as ssim +except: print('Warning: skimage is not installed') +#torchjaekwon import +from TorchJaekwon.Util.UtilAudioMelSpec import UtilAudioMelSpec +from TorchJaekwon.Util.UtilAudio import UtilAudio +#internal import + +class MetricVoice: + def __init__(self, + sample_rate:int = 16000, + nfft:Optional[int] = None, + hop_size:Optional[int] = None, + mel_size:Optional[int] = None, + frequency_min:Optional[float] = None, + frequency_max:Optional[float] = None) -> None: + + spec_config_of_sr_dict:Dict[int,dict] = { + ''' + 16000:{ + nfft: 512, + hop_size: 256, + mel_size: 80, + frequency_min: 0, + frequency_max: float(sample_rate // 2) + }, + 44100:{ + nfft: 1024, + hop_size: 512, + mel_size: 128, + frequency_min: 0, + frequency_max: float(sample_rate // 2) + } + ''' + } + spec_config_of_sr:dict = UtilAudioMelSpec.get_default_mel_spec_config(sample_rate = sample_rate) if sample_rate not in spec_config_of_sr_dict else spec_config_of_sr_dict[sample_rate] + + self.util_mel = UtilAudioMelSpec(nfft = spec_config_of_sr['nfft'] if nfft is None else nfft, + hop_size = 
spec_config_of_sr['hop_size'] if hop_size is None else hop_size, + sample_rate = sample_rate, + mel_size = spec_config_of_sr['mel_size'] if mel_size is None else mel_size, + frequency_min = spec_config_of_sr['frequency_min'] if frequency_min is None else frequency_min, + frequency_max = spec_config_of_sr['frequency_max'] if frequency_max is None else frequency_max) + + def get_spec_metrics_from_audio( + self, + pred, #linear scale spectrogram [time] + target, + metric_list:list = ['lsd','ssim','sispnr', 'l1', 'l2'] + ) -> Dict[str,float]: + + source_spec_dict = self.get_spec_dict_of_audio(pred) + target_spec_dict = self.get_spec_dict_of_audio(target) + + metric_dict = dict() + for spec_name in source_spec_dict: + if 'lsd' in metric_list: + metric_dict[f'lsd_{spec_name}'] = MetricVoice.get_lsd_from_spec(source_spec_dict[spec_name],target_spec_dict[spec_name]) + if 'ssim' in metric_list: + metric_dict[f'ssim_{spec_name}'] = self.get_ssim(source_spec_dict[spec_name], target_spec_dict[spec_name]) + + linear_spec_name = list(source_spec_dict.keys()) + for spec_name in linear_spec_name: + source_spec_dict[f'{spec_name}_log'] = np.log10(np.clip(source_spec_dict[spec_name], a_min=1e-8, a_max=None)) + target_spec_dict[f'{spec_name}_log'] = np.log10(np.clip(target_spec_dict[spec_name], a_min=1e-8, a_max=None)) + + for spec_name in source_spec_dict: + if 'l1' in metric_list: metric_dict[f'l1_{spec_name}'] = float(np.mean(np.abs(source_spec_dict[spec_name] - target_spec_dict[spec_name]))) + if 'l2' in metric_list: metric_dict[f'l2_{spec_name}'] = float(np.mean(np.square(source_spec_dict[spec_name] - target_spec_dict[spec_name]))) + if 'sispnr' in metric_list: metric_dict[f'sispnr_{spec_name}'] = MetricVoice.get_sispnr(torch.from_numpy(source_spec_dict[spec_name]),torch.from_numpy(target_spec_dict[spec_name])) + + return metric_dict + + + def get_lsd_from_audio(self, + pred, # [time] + target # [time] + ) -> Dict[str,float] : + pred_spec_dict = self.get_spec_dict_of_audio(pred) + target_spec_dict = self.get_spec_dict_of_audio(target) + lsd_dict = dict() + for spec_name in pred_spec_dict: + lsd_dict[spec_name] = MetricVoice.get_lsd_from_spec(pred_spec_dict[spec_name],target_spec_dict[spec_name]) + return lsd_dict + + def get_spec_dict_of_audio(self,audio): + spectrogram_mag = self.util_mel.stft_torch(audio)['mag'].float() + mel_spec = self.util_mel.spec_to_mel_spec(spectrogram_mag) + return {'spec_mag':spectrogram_mag.squeeze().detach().cpu().numpy(), 'mel': mel_spec.squeeze().detach().cpu().numpy()} + + @staticmethod + def get_lsd_from_spec(pred, #linear scale spectrogram [freq, time] + target, + eps = 1e-12): + #log_spectral_distance + # in non-log scale + lsd = ((target + eps)**2)/((pred + eps)**2) + lsd = lsd + eps + lsd = np.log10(lsd)**2 #torch.log10((target**2/((source + eps)**2)) + eps)**2 + lsd = np.mean(np.mean(lsd,axis=0)**0.5,axis=0) #torch.mean(torch.mean(lsd,dim=3)**0.5,dim=2) + return float(lsd) + + @staticmethod + def get_si_sdr(source, target): + alpha = np.dot(target, source)/np.linalg.norm(source)**2 + sdr = 10*np.log10(np.linalg.norm(alpha*source)**2/np.linalg.norm( + alpha*source - target)**2) + return sdr + + @staticmethod + def get_pesq(source:ndarray, #[time] + target:ndarray, #[time] + sample_rate:int = [8000,16000][1], + band:str = ['wide-band','narrow-band'][0]): + assert (sample_rate in [8000,16000]), f'sample rate must be either 8000 or 16000. 
current sample rate {sample_rate}' + if (sample_rate == 16000 and band == 'narrow-band'): print('Warning: narrowband (nb) mode only when sampling rate is 8000Hz') + if band == 'wide-band': + return pesq(sample_rate, target, source, 'wb') + else: + return pesq(sample_rate, target, source, 'nb') + + @staticmethod + def get_mcd(source:ndarray, #[time] + target:ndarray, #[time] + sample_rate:int, + frame_period=5): + cost_function = MetricVoice.dB_distance + mgc_source = MetricVoice.get_mgc(source, sample_rate, frame_period) + mgc_target = MetricVoice.get_mgc(target, sample_rate, frame_period) + + length = min(mgc_source.shape[0], mgc_target.shape[0]) + mgc_source = mgc_source[:length] + mgc_target = mgc_target[:length] + + mcd, _ = fastdtw(mgc_source[..., 1:], mgc_target[..., 1:], dist=cost_function) + mcd = mcd/length + + return float(mcd), length + + @staticmethod + def get_mgc(audio, sample_rate, frame_period, fft_size=512, mcep_size=60, alpha=0.65): + if isinstance(audio, Tensor): + if audio.ndim > 1: + audio = audio[0] + + audio = audio.numpy() + + _, sp, _ = pw.wav2world( + audio.astype(np.double), fs=sample_rate, frame_period=frame_period, fft_size=fft_size) + mgc = pysptk.sptk.mcep( + sp, order=mcep_size, alpha=alpha, maxiter=0, etype=1, eps=1.0E-8, min_det=0.0, itype=3) + + return mgc + + @staticmethod + def dB_distance(source, target): + dB_const = 10.0/np.log(10.0)*np.sqrt(2.0) + distance = source - target + + return dB_const*np.sqrt(np.inner(distance, distance)) + + @staticmethod + def get_sispnr(source, target, eps = 1e-12): + # scale_invariant_spectrogram_to_noise_ratio + # in log scale + output, target = UtilAudio.energy_unify(source, target) + noise = output - target + # print(pow_p_norm(target) , pow_p_norm(noise), pow_p_norm(target) / (pow_p_norm(noise) + EPS)) + sp_loss = 10 * torch.log10((UtilAudio.pow_p_norm(target) / (UtilAudio.pow_p_norm(noise) + eps) + eps)) + return float(sp_loss) + + @staticmethod + def get_ssim(source, target, data_range=None): + if data_range is None: + data_range = max(source.max(), target.max()) - min(source.min(), target.min()) + return float(ssim(source, target, win_size=7, data_range=data_range)) + + +''' + + + + def get_sdr_torchmetrics(self,pred_audio:Union[Tensor,ndarray], target_audio:Union[Tensor,ndarray]) -> dict: + result_dict = dict() + audio_spec_tensor_dict:dict = self.get_audio_and_spec_tensor_pred_target_dict(pred_audio, target_audio) + for data_type in audio_spec_tensor_dict["pred"]: + sdr = SignalDistortionRatio() + result_dict[f"sdr_torchmetrics_{data_type}"] = float(sdr(audio_spec_tensor_dict["pred"][data_type].clone(),audio_spec_tensor_dict["target"][data_type].clone())) + + return result_dict + + + +def get_f0(audio, sample_rate, frame_period=5, method='dio'): + if isinstance(audio, torch.Tensor): + if audio.ndim > 1: + audio = audio[0] + + audio = audio.numpy() + + hop_size = int(frame_period*sample_rate/1000) + if method == 'dio': + f0, _ = pw.dio(audio.astype(np.double), sample_rate, frame_period=frame_period) + elif method == 'harvest': + f0, _ = pw.harvest(audio.astype(np.double), sample_rate, frame_period=frame_period) + elif method == 'swipe': + f0 = pysptk.sptk.swipe(audio.astype(np.double), sample_rate, hopsize=hop_size) + elif method == 'rapt': + f0 = pysptk.sptk.rapt(audio.astype(np.double), sample_rate, hopsize=hop_size) + else: + raise ValueError(f'No such f0 extract method, {method}.') + + f0 = torch.from_numpy(f0) + vuv = 1*(f0 != 0.0) + + return f0, vuv + + +def get_f0_rmse(source, target, sample_rate, 
frame_period=5, method='dio'): + length = min(source.shape[-1], target.shape[-1]) + + source_f0, source_v = get_f0(source[...,:length], sample_rate, frame_period, method) + target_f0, target_v = get_f0(target[...,:length], sample_rate, frame_period, method) + + source_uv = 1 - source_v + target_uv = 1 - target_v + tp_mask = source_v*target_v + + length = tp_mask.sum().item() + + f0_rmse = 1200.0*torch.abs(torch.log2(target_f0 + target_uv) - torch.log2(source_f0 + source_uv)) + f0_rmse = tp_mask*f0_rmse + f0_rmse = f0_rmse.sum()/length + + return f0_rmse.item(), length +''' \ No newline at end of file diff --git a/TorchJaekwon/Evaluater/Package/pysepm/__init__.py b/TorchJaekwon/Evaluater/Package/pysepm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc724e1591007963cb29ef95b2bf23571806dcb --- /dev/null +++ b/TorchJaekwon/Evaluater/Package/pysepm/__init__.py @@ -0,0 +1,8 @@ +__version__ = '0.1' + + +from .qualityMeasures import fwSNRseg,SNRseg,llr,wss,composite,pesq,cepstrum_distance +from .intelligibilityMeasures import stoi,csii,ncm +from .reverberationMeasures import srr_seg,bsd,srmr + + diff --git a/TorchJaekwon/Evaluater/Package/pysepm/intelligibilityMeasures.py b/TorchJaekwon/Evaluater/Package/pysepm/intelligibilityMeasures.py new file mode 100644 index 0000000000000000000000000000000000000000..0afd51e3f315a8ffea479efefae54d379e7f601c --- /dev/null +++ b/TorchJaekwon/Evaluater/Package/pysepm/intelligibilityMeasures.py @@ -0,0 +1,281 @@ +from scipy.signal import stft,resample,butter,lfilter,hilbert +from scipy.interpolate import interp1d +try: from pystoi import stoi as pystoi # https://github.com/mpariente/pystoi +except: + print('pystoi not found') + pystoi = None +import numpy as np + +from .util import extract_overlapped_windows,resample_matlab_like + +stoi = pystoi + +def fwseg_noise(clean_speech, processed_speech,fs,frameLen=0.03, overlap=0.75): + + clean_length = len(clean_speech) + processed_length = len(processed_speech) + rms_all=np.linalg.norm(clean_speech)/np.sqrt(processed_length) + + winlength = round(frameLen*fs) #window length in samples + skiprate = int(np.floor((1-overlap)*frameLen*fs)) #window skip in samples + max_freq = fs/2 #maximum bandwidth + num_crit = 16 # number of critical bands + n_fft = int(2**np.ceil(np.log2(2*winlength))) + n_fftby2 = int(n_fft/2) + + cent_freq=np.zeros((num_crit,)) + bandwidth=np.zeros((num_crit,)) + + # ---------------------------------------------------------------------- + # Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz) + # ---------------------------------------------------------------------- + cent_freq[0] = 150.0000; bandwidth[0] = 100.0000; + cent_freq[1] = 250.000; bandwidth[1] = 100.0000; + cent_freq[2] = 350.000; bandwidth[2] = 100.0000; + cent_freq[3] = 450.000; bandwidth[3] = 110.0000; + cent_freq[4] = 570.000; bandwidth[4] = 120.0000; + cent_freq[5] = 700.000; bandwidth[5] = 140.0000; + cent_freq[6] = 840.000; bandwidth[6] = 150.0000; + cent_freq[7] = 1000.000; bandwidth[7] = 160.000; + cent_freq[8] = 1170.000; bandwidth[8] = 190.000; + cent_freq[9] = 1370.000; bandwidth[9] = 210.000; + cent_freq[10] = 1600.000; bandwidth[10]= 240.000; + cent_freq[11] = 1850.000; bandwidth[11]= 280.000; + cent_freq[12] = 2150.000; bandwidth[12]= 320.000; + cent_freq[13] = 2500.000; bandwidth[13]= 380.000; + cent_freq[14] = 2900.000; bandwidth[14]= 450.000; + cent_freq[15] = 3400.000; bandwidth[15]= 550.000; + + 
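# Note: the 16 values defined next are per-band importance weights for the critical bands listed above; later in this function they are assigned to W_freq and used to form the weighted average of the per-band transmission values Tjm (cf. the Kates (2005) references below).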
Weight=np.array([0.0192,0.0312,0.0926,0.1031,0.0735,0.0611,0.0495,0.044,0.044,0.049,0.0486,0.0493, 0.049,0.0547,0.0555,0.0493]) + + # ---------------------------------------------------------------------- + # Set up the critical band filters. Note here that Gaussianly shaped + # filters are used. Also, the sum of the filter weights are equivalent + # for each critical band filter. Filter less than -30 dB and set to + # zero. + # ---------------------------------------------------------------------- + + all_f0=np.zeros((num_crit,)) + crit_filter=np.zeros((num_crit,int(n_fftby2))) + g = np.zeros((num_crit,n_fftby2)) + + b = bandwidth; + q = cent_freq/1000; + p = 4*1000*q/b; # Eq. (7) + + #15.625=4000/256 + j = np.arange(0,n_fftby2) + + for i in range(num_crit): + g[i,:]=np.abs(1-j*(fs/n_fft)/(q[i]*1000));# Eq. (9) + crit_filter[i,:] = (1+p[i]*g[i,:])*np.exp(-p[i]*g[i,:]);# Eq. (8) + + num_frames = int(clean_length/skiprate-(winlength/skiprate)); # number of frames + start = 0 # starting sample + hannWin = 0.5*(1-np.cos(2*np.pi*np.arange(1,winlength+1)/(winlength+1))) + + f,t,clean_spec=stft(clean_speech[0:int(num_frames)*skiprate+int(winlength-skiprate)], fs=fs, window=hannWin, nperseg=winlength, noverlap=winlength-skiprate, nfft=n_fft, detrend=False, return_onesided=False, boundary=None, padded=False) + f,t,processed_spec=stft(processed_speech[0:int(num_frames)*skiprate+int(winlength-skiprate)], fs=fs, window=hannWin, nperseg=winlength, noverlap=winlength-skiprate, nfft=n_fft, detrend=False, return_onesided=False, boundary=None, padded=False) + + clean_frames = extract_overlapped_windows(clean_speech[0:int(num_frames)*skiprate+int(winlength-skiprate)],winlength,winlength-skiprate,None) + rms_seg = np.linalg.norm(clean_frames,axis=-1)/np.sqrt(winlength); + rms_db = 20*np.log10(rms_seg/rms_all); + #-------------------------------------------------------------- + # cal r2_high,r2_middle,r2_low + highInd = np.where(rms_db>=0) + highInd = highInd[0] + middleInd = np.where((rms_db>=-10) & (rms_db<0)) + middleInd = middleInd[0] + lowInd = np.where(rms_db<-10) + lowInd = lowInd[0] + + num_high = np.sum(clean_spec[0:n_fftby2,highInd]*np.conj(processed_spec[0:n_fftby2,highInd]),axis=-1) + denx_high = np.sum(np.abs(clean_spec[0:n_fftby2,highInd])**2,axis=-1) + deny_high = np.sum(np.abs(processed_spec[0:n_fftby2,highInd])**2,axis=-1); + + num_middle = np.sum(clean_spec[0:n_fftby2,middleInd]*np.conj(processed_spec[0:n_fftby2,middleInd]),axis=-1) + denx_middle = np.sum(np.abs(clean_spec[0:n_fftby2,middleInd])**2,axis=-1) + deny_middle = np.sum(np.abs(processed_spec[0:n_fftby2,middleInd])**2,axis=-1); + + num_low = np.sum(clean_spec[0:n_fftby2,lowInd]*np.conj(processed_spec[0:n_fftby2,lowInd]),axis=-1) + denx_low = np.sum(np.abs(clean_spec[0:n_fftby2,lowInd])**2,axis=-1) + deny_low = np.sum(np.abs(processed_spec[0:n_fftby2,lowInd])**2,axis=-1); + + num2_high = np.abs(num_high)**2; + r2_high = num2_high/(denx_high*deny_high); + + num2_middle = np.abs(num_middle)**2; + r2_middle = num2_middle/(denx_middle*deny_middle); + + num2_low = np.abs(num_low)**2; + r2_low = num2_low/(denx_low*deny_low); + #-------------------------------------------------------------- + # cal distortion frame by frame + + clean_spec = np.abs(clean_spec); + processed_spec = np.abs(processed_spec)**2; + + W_freq=Weight + + processed_energy = crit_filter.dot((processed_spec[0:n_fftby2,highInd].T*r2_high).T) + de_processed_energy= crit_filter.dot((processed_spec[0:n_fftby2,highInd].T*(1-r2_high)).T) + SDR = 
processed_energy/de_processed_energy;# Eq 13 in Kates (2005) + SDRlog=10*np.log10(SDR); + SDRlog_lim = SDRlog + SDRlog_lim[SDRlog_lim<-15]=-15 + SDRlog_lim[SDRlog_lim>15]=15 # limit between [-15, 15] + Tjm = (SDRlog_lim+15)/30; + distortionh = W_freq.dot(Tjm)/np.sum(W_freq,axis=0) + distortionh[distortionh<0]=0 + + + processed_energy = crit_filter.dot((processed_spec[0:n_fftby2,middleInd].T*r2_middle).T) + de_processed_energy= crit_filter.dot((processed_spec[0:n_fftby2,middleInd].T*(1-r2_middle)).T) + SDR = processed_energy/de_processed_energy;# Eq 13 in Kates (2005) + SDRlog=10*np.log10(SDR); + SDRlog_lim = SDRlog + SDRlog_lim[SDRlog_lim<-15]=-15 + SDRlog_lim[SDRlog_lim>15]=15 # limit between [-15, 15] + Tjm = (SDRlog_lim+15)/30; + distortionm = W_freq.dot(Tjm)/np.sum(W_freq,axis=0) + distortionm[distortionm<0]=0 + + processed_energy = crit_filter.dot((processed_spec[0:n_fftby2,lowInd].T*r2_low).T) + de_processed_energy= crit_filter.dot((processed_spec[0:n_fftby2,lowInd].T*(1-r2_low)).T) + SDR = processed_energy/de_processed_energy;# Eq 13 in Kates (2005) + SDRlog=10*np.log10(SDR); + SDRlog_lim = SDRlog + SDRlog_lim[SDRlog_lim<-15]=-15 + SDRlog_lim[SDRlog_lim>15]=15 # limit between [-15, 15] + Tjm = (SDRlog_lim+15)/30; + distortionl = W_freq.dot(Tjm)/np.sum(W_freq,axis=0) + distortionl[distortionl<0]=0 + + return distortionh,distortionm,distortionl + + +def csii(clean_speech, processed_speech,sample_rate): + sampleLen= min(len( clean_speech), len( processed_speech)) + clean_speech= clean_speech[0: sampleLen] + processed_speech= processed_speech[0: sampleLen] + vec_CSIIh,vec_CSIIm,vec_CSIIl = fwseg_noise(clean_speech, processed_speech, sample_rate) + + CSIIh=np.mean(vec_CSIIh) + CSIIm=np.mean(vec_CSIIm) + CSIIl=np.mean(vec_CSIIl) + return CSIIh,CSIIm,CSIIl + + + +def get_band(M,Fs): + # This function sets the bandpass filter band edges. + # It assumes that the sampling frequency is 8000 Hz. + A = 165 + a = 2.1 + K = 1 + L = 35 + CF = 300; + x_100 = (L/a)*np.log10(CF/A + K) + CF = Fs/2-600 + x_8000 = (L/a)*np.log10(CF/A + K); + LX = x_8000 - x_100 + x_step = LX / M + x = np.arange(x_100,x_8000+x_step+1e-20,x_step) + if len(x) == M: + np.append(x,x_8000) + + BAND = A*(10**(a*x/L) - K) + return BAND + +def get_ansis(BAND): + fcenter=(BAND[0:-1]+BAND[1:])/2; + + # Data from Table B.1 in "ANSI (1997). S3.5–1997 Methods for Calculation of the Speech Intelligibility + # Index. New York: American National Standards Institute." 
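# Note: f below lists that table's band center frequencies (Hz) and BIF the corresponding band-importance values; interp1d maps them onto the analysis band centers in fcenter.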
+ f=np.array([150,250,350,450,570,700,840,1000,1170,1370,1600,1850,2150,2500,2900,3400,4000,4800,5800,7000,8500]) + BIF=np.array([0.0192,0.0312,0.0926,0.1031,0.0735,0.0611,0.0495,0.0440,0.0440,0.0490,0.0486,0.0493,0.0490,0.0547,0.0555,0.0493,0.0359,0.0387,0.0256,0.0219,0.0043]) + f_ANSI = interp1d(f,BIF) + ANSIs= f_ANSI(fcenter); + return fcenter,ANSIs + + +def ncm(clean_speech,processed_speech,fs): + + if fs != 8000 and fs != 16000: + raise ValueError('fs must be either 8 kHz or 16 kHz') + + + + x= clean_speech # clean signal + y= processed_speech # noisy signal + F_SIGNAL = fs + + F_ENVELOPE = 32 # limits modulations to 0<f<16 Hz + M_CHANNELS = 20 # number of analysis bands (assumed value) + BAND = get_band(M_CHANNELS,F_SIGNAL) + fcenter,WEIGHT = get_ansis(BAND) + + Lx = len(x) + Ly = len(y) + if Lx > Ly: + x = x[0:Ly] + if Ly > Lx: + y = y[0:Lx] + + Lx = len(x); + Ly = len(y); + + X_BANDS = np.zeros((Lx,M_CHANNELS)) + Y_BANDS = np.zeros((Lx,M_CHANNELS)) + + # DESIGN BANDPASS FILTERS + for a in range(M_CHANNELS): + B_bp,A_bp = butter( 4 , np.array([BAND[a],BAND[a+1]])*(2/F_SIGNAL),btype='bandpass') + X_BANDS[:,a] = lfilter( B_bp , A_bp , x ) + Y_BANDS[:,a] = lfilter( B_bp , A_bp , y ) + + gcd = np.gcd(F_SIGNAL, F_ENVELOPE) + # CALCULATE HILBERT ENVELOPES, and resample at F_ENVELOPE Hz + analytic_x = hilbert( X_BANDS,axis=0); + X = np.abs( analytic_x ); + #X = resample( X , round(len(x)/F_SIGNAL*F_ENVELOPE)); + X = resample_matlab_like(X,F_ENVELOPE,F_SIGNAL) + analytic_y = hilbert( Y_BANDS,axis=0); + Y = np.abs( analytic_y ); + #Y = resample( Y , round(len(x)/F_SIGNAL*F_ENVELOPE)); + Y = resample_matlab_like(Y,F_ENVELOPE,F_SIGNAL) + ## ---compute weights based on clean signal's rms envelopes----- + # + Ldx, pp=X.shape + p=3 # power exponent - see Eq. 12 + + ro2 = np.zeros((M_CHANNELS,)) + asnr = np.zeros((M_CHANNELS,)) + TI = np.zeros((M_CHANNELS,)) + + for k in range(M_CHANNELS): + x_tmp= X[ :, k] + y_tmp= Y[ :, k] + lambda_x= np.linalg.norm( x_tmp- np.mean( x_tmp))**2 + lambda_y= np.linalg.norm( y_tmp- np.mean( y_tmp))**2 + lambda_xy= np.sum( (x_tmp- np.mean( x_tmp))*(y_tmp- np.mean( y_tmp))) + ro2[k]= (lambda_xy**2)/ (lambda_x*lambda_y) + asnr[k]= 10*np.log10( (ro2[k]+ 1e-20)/ (1- ro2[k]+ 1e-20)); # Eq.9 in [1] + + if asnr[k]< -15: + asnr[k]= -15 + elif asnr[k]> 15: + asnr[k]= 15 + + TI[k]= (asnr[k]+ 15)/ 30 # Eq.10 in [1] + + ncm_val= WEIGHT.dot(TI)/np.sum(WEIGHT) # Eq.11 + return ncm_val diff --git a/TorchJaekwon/Evaluater/Package/pysepm/qualityMeasures.py b/TorchJaekwon/Evaluater/Package/pysepm/qualityMeasures.py new file mode 100644 index 0000000000000000000000000000000000000000..7c6ab22e67c2fe3b9390bb00a1dd4a58b9c010d5 --- /dev/null +++ b/TorchJaekwon/Evaluater/Package/pysepm/qualityMeasures.py @@ -0,0 +1,438 @@ +from scipy.signal import stft +from scipy.linalg import toeplitz +import scipy +try: import pesq as pypesq # https://github.com/ludlows/python-pesq +except: print('Module pesq not found.') +import numpy as np +from numba import jit +from .util import extract_overlapped_windows + +def SNRseg(clean_speech, processed_speech,fs, frameLen=0.03, overlap=0.75): + eps=np.finfo(np.float64).eps + + winlength = round(frameLen*fs) #window length in samples + skiprate = int(np.floor((1-overlap)*frameLen*fs)) #window skip in samples + MIN_SNR = -10 # minimum SNR in dB + MAX_SNR = 35 # maximum SNR in dB + + hannWin=0.5*(1-np.cos(2*np.pi*np.arange(1,winlength+1)/(winlength+1))) + clean_speech_framed=extract_overlapped_windows(clean_speech,winlength,winlength-skiprate,hannWin) + processed_speech_framed=extract_overlapped_windows(processed_speech,winlength,winlength-skiprate,hannWin) + + signal_energy = np.power(clean_speech_framed,2).sum(-1) + noise_energy = 
np.power(clean_speech_framed-processed_speech_framed,2).sum(-1) + + segmental_snr = 10*np.log10(signal_energy/(noise_energy+eps)+eps) + segmental_snr[segmental_snr<MIN_SNR]=MIN_SNR + segmental_snr[segmental_snr>MAX_SNR]=MAX_SNR + segmental_snr=segmental_snr[:-1] # remove last frame -> not valid + return np.mean(segmental_snr) + +def fwSNRseg(cleanSig, enhancedSig, fs, frameLen=0.03, overlap=0.75): + if cleanSig.shape!=enhancedSig.shape: + raise ValueError('The two signals do not match!') + eps=np.finfo(np.float64).eps + cleanSig=cleanSig.astype(np.float64)+eps + enhancedSig=enhancedSig.astype(np.float64)+eps + winlength = round(frameLen*fs) #window length in samples + skiprate = int(np.floor((1-overlap)*frameLen*fs)) #window skip in samples + max_freq = fs/2 #maximum bandwidth + num_crit = 25# number of critical bands + n_fft = 2**np.ceil(np.log2(2*winlength)) + n_fftby2 = int(n_fft/2) + gamma=0.2 + + cent_freq=np.zeros((num_crit,)) + bandwidth=np.zeros((num_crit,)) + + cent_freq[0] = 50.0000; bandwidth[0] = 70.0000; + cent_freq[1] = 120.000; bandwidth[1] = 70.0000; + cent_freq[2] = 190.000; bandwidth[2] = 70.0000; + cent_freq[3] = 260.000; bandwidth[3] = 70.0000; + cent_freq[4] = 330.000; bandwidth[4] = 70.0000; + cent_freq[5] = 400.000; bandwidth[5] = 70.0000; + cent_freq[6] = 470.000; bandwidth[6] = 70.0000; + cent_freq[7] = 540.000; bandwidth[7] = 77.3724; + cent_freq[8] = 617.372; bandwidth[8] = 86.0056; + cent_freq[9] = 703.378; bandwidth[9] = 95.3398; + cent_freq[10] = 798.717; bandwidth[10] = 105.411; + cent_freq[11] = 904.128; bandwidth[11] = 116.256; + cent_freq[12] = 1020.38; bandwidth[12] = 127.914; + cent_freq[13] = 1148.30; bandwidth[13] = 140.423; + cent_freq[14] = 1288.72; bandwidth[14] = 153.823; + cent_freq[15] = 1442.54; bandwidth[15] = 168.154; + cent_freq[16] = 1610.70; bandwidth[16] = 183.457; + cent_freq[17] = 1794.16; bandwidth[17] = 199.776; + cent_freq[18] = 1993.93; bandwidth[18] = 217.153; + cent_freq[19] = 2211.08; bandwidth[19] = 235.631; + cent_freq[20] = 2446.71; bandwidth[20] = 255.255; + cent_freq[21] = 2701.97; bandwidth[21] = 276.072; + cent_freq[22] = 2978.04; bandwidth[22] = 298.126; + cent_freq[23] = 3276.17; bandwidth[23] = 321.465; + cent_freq[24] = 3597.63; bandwidth[24] = 346.136; + + + W=np.array([0.003,0.003,0.003,0.007,0.010,0.016,0.016,0.017,0.017,0.022,0.027,0.028,0.030,0.032,0.034,0.035,0.037,0.036,0.036,0.033,0.030,0.029,0.027,0.026, + 0.026]) + + bw_min=bandwidth[0] + min_factor = np.exp (-30.0 / (2.0 * 2.303));# % -30 dB point of filter + + all_f0=np.zeros((num_crit,)) + crit_filter=np.zeros((num_crit,int(n_fftby2))) + j = np.arange(0,n_fftby2) + + + for i in range(num_crit): + f0 = (cent_freq[i] / max_freq) * (n_fftby2) + all_f0[i] = np.floor(f0); + bw = (bandwidth[i] / max_freq) * (n_fftby2); + norm_factor = np.log(bw_min) - np.log(bandwidth[i]); + crit_filter[i,:] = np.exp (-11 *(((j - np.floor(f0))/bw)**2) + norm_factor) + crit_filter[i,:] = crit_filter[i,:]*(crit_filter[i,:] > min_factor) + + num_frames = len(cleanSig)/skiprate-(winlength/skiprate)# number of frames + start = 1 # starting sample + #window = 0.5*(1 - cos(2*pi*(1:winlength).T/(winlength+1))); + + + hannWin=0.5*(1-np.cos(2*np.pi*np.arange(1,winlength+1)/(winlength+1))) + f,t,Zxx=stft(cleanSig[0:int(num_frames)*skiprate+int(winlength-skiprate)], fs=fs, window=hannWin, nperseg=winlength, noverlap=winlength-skiprate, nfft=n_fft, detrend=False, return_onesided=True, boundary=None, padded=False) + clean_spec=np.abs(Zxx) + clean_spec=clean_spec[:-1,:] + clean_spec=(clean_spec/clean_spec.sum(0)) + 
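# Note: the enhanced signal is transformed the same way below; both magnitude spectra are then collapsed onto the 25 critical bands with crit_filter and compared frame by frame as a frequency-weighted segmental SNR.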
f,t,Zxx=stft(enhancedSig[0:int(num_frames)*skiprate+int(winlength-skiprate)], fs=fs, window=hannWin, nperseg=winlength, noverlap=winlength-skiprate, nfft=n_fft, detrend=False, return_onesided=True, boundary=None, padded=False) + enh_spec=np.abs(Zxx) + enh_spec=enh_spec[:-1,:] + enh_spec=(enh_spec/enh_spec.sum(0)) + + clean_energy=(crit_filter.dot(clean_spec)) + processed_energy=(crit_filter.dot(enh_spec)) + error_energy=np.power(clean_energy-processed_energy,2) + error_energy[error_energy<eps]=eps + W_freq=np.power(clean_energy,gamma) + SNRlog=10*np.log10((clean_energy**2)/error_energy) + fwSNR=np.sum(W_freq*SNRlog,1)/np.sum(W_freq,1) + distortion=fwSNR.copy() + distortion[distortion<-10]=-10 + distortion[distortion>35]=35 + + return np.mean(distortion) +@jit +def lpcoeff(speech_frame, model_order): + eps=np.finfo(np.float64).eps + # ---------------------------------------------------------- + # (1) Compute Autocorrelation Lags + # ---------------------------------------------------------- + winlength = max(speech_frame.shape) + R = np.zeros((model_order+1,)) + for k in range(model_order+1): + if k==0: + R[k]=np.sum(speech_frame[0:]*speech_frame[0:]) + else: + R[k]=np.sum(speech_frame[0:-k]*speech_frame[k:]) + + + #R=scipy.signal.correlate(speech_frame,speech_frame) + #R=R[len(speech_frame)-1:len(speech_frame)+model_order] + # ---------------------------------------------------------- + # (2) Levinson-Durbin + # ---------------------------------------------------------- + a = np.ones((model_order,)) + a_past = np.ones((model_order,)) + rcoeff = np.zeros((model_order,)) + E = np.zeros((model_order+1,)) + + E[0]=R[0] + + for i in range(0,model_order): + a_past[0:i] = a[0:i] + + sum_term = np.sum(a_past[0:i]*R[i:0:-1]) + + if E[i]==0.0: # prevents zero division error, numba doesn't allow try/except statements + rcoeff[i]= np.inf + else: + rcoeff[i]=(R[i+1] - sum_term) / (E[i]) + + a[i]=rcoeff[i] + #if i==0: + # a[0:i] = a_past[0:i] - rcoeff[i]*np.array([]) + #else: + if i>0: + a[0:i] = a_past[0:i] - rcoeff[i]*a_past[i-1::-1] + + E[i+1]=(1-rcoeff[i]*rcoeff[i])*E[i] + + acorr = R; + refcoeff = rcoeff; + lpparams = np.ones((model_order+1,)) + lpparams[1:] = -a + return(lpparams,R) + +def llr(clean_speech, processed_speech, fs, used_for_composite=False, frameLen=0.03, overlap=0.75): + eps=np.finfo(np.float64).eps + alpha = 0.95 + winlength = round(frameLen*fs) #window length in samples + skiprate = int(np.floor((1-overlap)*frameLen*fs)) #window skip in samples + if fs<10000: + P = 10 # LPC Analysis Order + else: + P = 16 # this could vary depending on sampling frequency. 
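# Note: both signals are framed below and, per frame, the LPC coefficients of the clean and processed frames are compared through the clean-frame autocorrelation (Toeplitz) matrix; after sorting, only the lowest 95% of the per-frame log ratios are averaged.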
+ + hannWin=0.5*(1-np.cos(2*np.pi*np.arange(1,winlength+1)/(winlength+1))) + clean_speech_framed=extract_overlapped_windows(clean_speech+eps,winlength,winlength-skiprate,hannWin) + processed_speech_framed=extract_overlapped_windows(processed_speech+eps,winlength,winlength-skiprate,hannWin) + numFrames=clean_speech_framed.shape[0] + numerators = np.zeros((numFrames-1,)) + denominators = np.zeros((numFrames-1,)) + + for ii in range(numFrames-1): + A_clean,R_clean=lpcoeff(clean_speech_framed[ii,:],P) + A_proc,R_proc=lpcoeff(processed_speech_framed[ii,:],P) + + numerators[ii]=A_proc.dot(toeplitz(R_clean).dot(A_proc.T)) + denominators[ii]=A_clean.dot(toeplitz(R_clean).dot(A_clean.T)) + + + frac=numerators/(denominators) + frac[np.isnan(frac)]=np.inf + frac[frac<=0]=1000 + distortion = np.log(frac) + if not used_for_composite: + distortion[distortion>2]=2 # this line is not in composite measure but in llr matlab implementation of loizou + distortion = np.sort(distortion) + distortion = distortion[:int(round(len(distortion)*alpha))] + return np.mean(distortion) + + +@jit +def find_loc_peaks(slope,energy): + num_crit = len(energy) + + loc_peaks=np.zeros_like(slope) + + for ii in range(len(slope)): + n=ii + if slope[ii]>0: + while ((n<num_crit-1) and (slope[n] > 0)): + n=n+1 + loc_peaks[ii]=energy[n-1] + else: + while ((n>=0) and (slope[n] <= 0)): + n=n-1 + loc_peaks[ii]=energy[n+1] + + return loc_peaks + + + +def wss(clean_speech, processed_speech, fs, frameLen=0.03, overlap=0.75): + + Kmax = 20 # value suggested by Klatt, pg 1280 + Klocmax = 1 # value suggested by Klatt, pg 1280 + alpha = 0.95 + if clean_speech.shape!=processed_speech.shape: + raise ValueError('The two signals do not match!') + eps=np.finfo(np.float64).eps + clean_speech=clean_speech.astype(np.float64)+eps + processed_speech=processed_speech.astype(np.float64)+eps + winlength = round(frameLen*fs) #window length in samples + skiprate = int(np.floor((1-overlap)*frameLen*fs)) #window skip in samples + max_freq = fs/2 #maximum bandwidth + num_crit = 25# number of critical bands + n_fft = 2**np.ceil(np.log2(2*winlength)) + n_fftby2 = int(n_fft/2) + + cent_freq=np.zeros((num_crit,)) + bandwidth=np.zeros((num_crit,)) + + cent_freq[0] = 50.0000; bandwidth[0] = 70.0000; + cent_freq[1] = 120.000; bandwidth[1] = 70.0000; + cent_freq[2] = 190.000; bandwidth[2] = 70.0000; + cent_freq[3] = 260.000; bandwidth[3] = 70.0000; + cent_freq[4] = 330.000; bandwidth[4] = 70.0000; + cent_freq[5] = 400.000; bandwidth[5] = 70.0000; + cent_freq[6] = 470.000; bandwidth[6] = 70.0000; + cent_freq[7] = 540.000; bandwidth[7] = 77.3724; + cent_freq[8] = 617.372; bandwidth[8] = 86.0056; + cent_freq[9] = 703.378; bandwidth[9] = 95.3398; + cent_freq[10] = 798.717; bandwidth[10] = 105.411; + cent_freq[11] = 904.128; bandwidth[11] = 116.256; + cent_freq[12] = 1020.38; bandwidth[12] = 127.914; + cent_freq[13] = 1148.30; bandwidth[13] = 140.423; + cent_freq[14] = 1288.72; bandwidth[14] = 153.823; + cent_freq[15] = 1442.54; bandwidth[15] = 168.154; + cent_freq[16] = 1610.70; bandwidth[16] = 183.457; + cent_freq[17] = 1794.16; bandwidth[17] = 199.776; + cent_freq[18] = 1993.93; bandwidth[18] = 217.153; + cent_freq[19] = 2211.08; bandwidth[19] = 235.631; + cent_freq[20] = 2446.71; bandwidth[20] = 255.255; + cent_freq[21] = 2701.97; bandwidth[21] = 276.072; + cent_freq[22] = 2978.04; bandwidth[22] = 298.126; + cent_freq[23] = 3276.17; bandwidth[23] = 321.465; + cent_freq[24] = 3597.63; bandwidth[24] = 346.136; + + + 
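# Note: the weight vector and Gaussian-shaped critical-band filters set up next mirror fwSNRseg; wss then derives Klatt's spectral-slope weights (Wmax, Wlocmax) from the per-band log energies and averages the weighted squared slope differences between clean and processed spectra.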
W=np.array([0.003,0.003,0.003,0.007,0.010,0.016,0.016,0.017,0.017,0.022,0.027,0.028,0.030,0.032,0.034,0.035,0.037,0.036,0.036,0.033,0.030,0.029,0.027,0.026, + 0.026]) + + bw_min=bandwidth[0] + min_factor = np.exp (-30.0 / (2.0 * 2.303));# % -30 dB point of filter + + all_f0=np.zeros((num_crit,)) + crit_filter=np.zeros((num_crit,int(n_fftby2))) + j = np.arange(0,n_fftby2) + + + for i in range(num_crit): + f0 = (cent_freq[i] / max_freq) * (n_fftby2) + all_f0[i] = np.floor(f0); + bw = (bandwidth[i] / max_freq) * (n_fftby2); + norm_factor = np.log(bw_min) - np.log(bandwidth[i]); + crit_filter[i,:] = np.exp (-11 *(((j - np.floor(f0))/bw)**2) + norm_factor) + crit_filter[i,:] = crit_filter[i,:]*(crit_filter[i,:] > min_factor) + + num_frames = len(clean_speech)/skiprate-(winlength/skiprate)# number of frames + start = 1 # starting sample + + hannWin=0.5*(1-np.cos(2*np.pi*np.arange(1,winlength+1)/(winlength+1))) + scale = np.sqrt(1.0 / hannWin.sum()**2) + + f,t,Zxx=stft(clean_speech[0:int(num_frames)*skiprate+int(winlength-skiprate)], fs=fs, window=hannWin, nperseg=winlength, noverlap=winlength-skiprate, nfft=n_fft, detrend=False, return_onesided=True, boundary=None, padded=False) + clean_spec=np.power(np.abs(Zxx)/scale,2) + clean_spec=clean_spec[:-1,:] + + f,t,Zxx=stft(processed_speech[0:int(num_frames)*skiprate+int(winlength-skiprate)], fs=fs, window=hannWin, nperseg=winlength, noverlap=winlength-skiprate, nfft=n_fft, detrend=False, return_onesided=True, boundary=None, padded=False) + proc_spec=np.power(np.abs(Zxx)/scale,2) + proc_spec=proc_spec[:-1,:] + + clean_energy=(crit_filter.dot(clean_spec)) + log_clean_energy=10*np.log10(clean_energy) + log_clean_energy[log_clean_energy<-100]=-100 + proc_energy=(crit_filter.dot(proc_spec)) + log_proc_energy=10*np.log10(proc_energy) + log_proc_energy[log_proc_energy<-100]=-100 + + log_clean_energy_slope=np.diff(log_clean_energy,axis=0) + log_proc_energy_slope=np.diff(log_proc_energy,axis=0) + + dBMax_clean = np.max(log_clean_energy,axis=0) + dBMax_processed = np.max(log_proc_energy,axis=0) + + numFrames=log_clean_energy_slope.shape[-1] + + clean_loc_peaks=np.zeros_like(log_clean_energy_slope) + proc_loc_peaks=np.zeros_like(log_proc_energy_slope) + for ii in range(numFrames): + clean_loc_peaks[:,ii]=find_loc_peaks(log_clean_energy_slope[:,ii],log_clean_energy[:,ii]) + proc_loc_peaks[:,ii]=find_loc_peaks(log_proc_energy_slope[:,ii],log_proc_energy[:,ii]) + + + Wmax_clean = Kmax / (Kmax + dBMax_clean - log_clean_energy[:-1,:]) + Wlocmax_clean = Klocmax / ( Klocmax + clean_loc_peaks - log_clean_energy[:-1,:]) + W_clean = Wmax_clean * Wlocmax_clean + + Wmax_proc = Kmax / (Kmax + dBMax_processed - log_proc_energy[:-1]) + Wlocmax_proc = Klocmax / ( Klocmax + proc_loc_peaks - log_proc_energy[:-1,:]) + W_proc = Wmax_proc * Wlocmax_proc + + W = (W_clean + W_proc)/2.0 + + distortion=np.sum(W*(log_clean_energy_slope- log_proc_energy_slope)**2,axis=0) + distortion=distortion/np.sum(W,axis=0) + distortion = np.sort(distortion) + distortion = distortion[:int(round(len(distortion)*alpha))] + return np.mean(distortion) + +def pesq(clean_speech, processed_speech, fs): + if fs == 8000: + mos_lqo = pypesq.pesq(fs,clean_speech, processed_speech, 'nb') + pesq_mos = 46607/14945 - (2000*np.log(1/(mos_lqo/4 - 999/4000) - 1))/2989#0.999 + ( 4.999-0.999 ) / ( 1+np.exp(-1.4945*pesq_mos+4.6607) ) + elif fs == 16000: + mos_lqo = pypesq.pesq(fs,clean_speech, processed_speech, 'wb') + pesq_mos = np.NaN + else: + raise ValueError('fs must be either 8 kHz or 16 kHz') + + return 
pesq_mos,mos_lqo + + +def composite(clean_speech, processed_speech, fs): + wss_dist=wss(clean_speech, processed_speech, fs) + llr_mean=llr(clean_speech, processed_speech, fs,used_for_composite=True) + segSNR=SNRseg(clean_speech, processed_speech, fs) + pesq_mos,mos_lqo = pesq(clean_speech, processed_speech,fs) + + if fs >= 16e3: + used_pesq_val = mos_lqo + else: + used_pesq_val = pesq_mos + + Csig = 3.093 - 1.029*llr_mean + 0.603*used_pesq_val-0.009*wss_dist + Csig = np.max((1,Csig)) + Csig = np.min((5, Csig)) # limit values to [1, 5] + Cbak = 1.634 + 0.478 *used_pesq_val - 0.007*wss_dist + 0.063*segSNR + Cbak = np.max((1, Cbak)) + Cbak = np.min((5,Cbak)) # limit values to [1, 5] + Covl = 1.594 + 0.805*used_pesq_val - 0.512*llr_mean - 0.007*wss_dist + Covl = np.max((1, Covl)) + Covl = np.min((5, Covl)) # limit values to [1, 5] + return Csig,Cbak,Covl + +@jit +def lpc2cep(a): + # + # converts prediction to cepstrum coefficients + # + # Author: Philipos C. Loizou + + M=len(a); + cep=np.zeros((M-1,)); + + cep[0]=-a[1] + + for k in range(2,M): + ix=np.arange(1,k) + vec1=cep[ix-1]*a[k-1:0:-1]*(ix) + cep[k-1]=-(a[k]+np.sum(vec1)/k); + return cep + + +def cepstrum_distance(clean_speech, processed_speech, fs, frameLen=0.03, overlap=0.75): + + + clean_length = len(clean_speech) + processed_length = len(processed_speech) + + winlength = round(frameLen*fs) #window length in samples + skiprate = int(np.floor((1-overlap)*frameLen*fs)) #window skip in samples + + if fs<10000: + P = 10 # LPC Analysis Order + else: + P=16; # this could vary depending on sampling frequency. + + C=10*np.sqrt(2)/np.log(10) + + numFrames = int(clean_length/skiprate-(winlength/skiprate)); # number of frames + + hannWin=0.5*(1-np.cos(2*np.pi*np.arange(1,winlength+1)/(winlength+1))) + clean_speech_framed=extract_overlapped_windows(clean_speech[0:int(numFrames)*skiprate+int(winlength-skiprate)],winlength,winlength-skiprate,hannWin) + processed_speech_framed=extract_overlapped_windows(processed_speech[0:int(numFrames)*skiprate+int(winlength-skiprate)],winlength,winlength-skiprate,hannWin) + distortion = np.zeros((numFrames,)) + + for ii in range(numFrames): + A_clean,R_clean=lpcoeff(clean_speech_framed[ii,:],P) + A_proc,R_proc=lpcoeff(processed_speech_framed[ii,:],P) + + C_clean=lpc2cep(A_clean) + C_processed=lpc2cep(A_proc) + distortion[ii] = min((10,C*np.linalg.norm(C_clean-C_processed))) + + IS_dist = distortion + alpha=0.95 + IS_len= round( len( IS_dist)* alpha) + IS = np.sort(IS_dist) + cep_mean= np.mean( IS[ 0: IS_len]) + return cep_mean \ No newline at end of file diff --git a/TorchJaekwon/Evaluater/Package/pysepm/reverberationMeasures.py b/TorchJaekwon/Evaluater/Package/pysepm/reverberationMeasures.py new file mode 100644 index 0000000000000000000000000000000000000000..0dc1231d47547b8fd82f8050d384b2d6f76acdb0 --- /dev/null +++ b/TorchJaekwon/Evaluater/Package/pysepm/reverberationMeasures.py @@ -0,0 +1,118 @@ +from scipy.signal import resample,stft +import scipy +try: import srmrpy #https://github.com/jfsantos/SRMRpy +except: print('srmrpy not installed') +import numpy as np +from .qualityMeasures import SNRseg + +def srr_seg(clean_speech, processed_speech,fs): + return SNRseg(clean_speech, processed_speech,fs) + + +def srmr(speech,fs, n_cochlear_filters=23, low_freq=125, min_cf=4, max_cf=128, fast=False, norm=False): + if fs == 8000: + srmRatio,energy=srmrpy.srmr(speech, fs, n_cochlear_filters=n_cochlear_filters, low_freq=low_freq, min_cf=min_cf, max_cf=max_cf, fast=fast, norm=norm) + return srmRatio + + elif fs == 16000: 
+ srmRatio,energy=srmrpy.srmr(speech, fs, n_cochlear_filters=n_cochlear_filters, low_freq=low_freq, min_cf=min_cf, max_cf=max_cf, fast=fast, norm=norm) + return srmRatio + else: + numSamples=round(len(speech)/fs*16000) + fs = 16000 + srmRatio,energy=srmrpy.srmr(resample(speech, numSamples), fs, n_cochlear_filters=n_cochlear_filters, low_freq=low_freq, min_cf=min_cf, max_cf=max_cf, fast=fast, norm=norm) + return srmRatio + + +def hz_to_bark(freqs_hz): + freqs_hz = np.asanyarray([freqs_hz]) + barks = (26.81*freqs_hz)/(1960+freqs_hz)-0.53 + barks[barks<2]=barks[barks<2]+0.15*(2-barks[barks<2]) + barks[barks>20.1]=barks[barks>20.1]+0.22*(barks[barks>20.1]-20.1) + return np.squeeze(barks) + +def bark_to_hz(barks): + barks = barks.copy() + barks = np.asanyarray([barks]) + barks[barks<2]=(barks[barks<2]-0.3)/0.85 + barks[barks>20.1]=(barks[barks>20.1]+4.422)/1.22 + freqs_hz = 1960 * (barks+0.53)/(26.28-barks) + return np.squeeze(freqs_hz) + +def bark_frequencies(n_barks=128, fmin=0.0, fmax=11025.0): + # 'Center freqs' of bark bands - uniformly spaced between limits + min_bark = hz_to_bark(fmin) + max_bark = hz_to_bark(fmax) + + barks = np.linspace(min_bark, max_bark, n_barks) + + return bark_to_hz(barks) + +def barks(fs, n_fft, n_barks=128, fmin=0.0, fmax=None, norm='area', dtype=np.float32): + + if fmax is None: + fmax = float(fs) / 2 + + + # Initialize the weights + n_barks = int(n_barks) + weights = np.zeros((n_barks, int(1 + n_fft // 2)), dtype=dtype) + + # Center freqs of each FFT bin + fftfreqs = np.linspace(0,float(fs) / 2,int(1 + n_fft//2), endpoint=True) + + # 'Center freqs' of mel bands - uniformly spaced between limits + bark_f = bark_frequencies(n_barks + 2, fmin=fmin, fmax=fmax) + + fdiff = np.diff(bark_f) + ramps = np.subtract.outer(bark_f, fftfreqs) + + for i in range(n_barks): + # lower and upper slopes for all bins + lower = -ramps[i] / fdiff[i] + upper = ramps[i+2] / fdiff[i+1] + + # .. 
then intersect them with each other and zero + weights[i] = np.maximum(0, np.minimum(lower, upper)) + + if norm in (1, 'area'): + weightsPerBand=np.sum(weights,1); + for i in range(weights.shape[0]): + weights[i,:]=weights[i,:]/weightsPerBand[i] + return weights + +def bsd(clean_speech, processed_speech, fs, frameLen=0.03, overlap=0.75): + + pre_emphasis_coeff = 0.95 + b = np.array([1]) + a = np.array([1,pre_emphasis_coeff]) + clean_speech = scipy.signal.lfilter(b,a,clean_speech) + processed_speech = scipy.signal.lfilter(b,a,processed_speech) + + winlength = round(frameLen*fs) #window length in samples + skiprate = int(np.floor((1-overlap)*frameLen*fs)) #window skip in samples + max_freq = fs/2 #maximum bandwidth + n_fft = 2**np.ceil(np.log2(2*winlength)) + n_fftby2 = int(n_fft/2) + num_frames = len(clean_speech)/skiprate-(winlength/skiprate)# number of frames + + hannWin=scipy.signal.windows.hann(winlength)#0.5*(1-np.cos(2*np.pi*np.arange(1,winlength+1)/(winlength+1))) + f,t,Zxx=stft(clean_speech[0:int(num_frames)*skiprate+int(winlength-skiprate)], fs=fs, window=hannWin, nperseg=winlength, noverlap=winlength-skiprate, nfft=n_fft, detrend=False, return_onesided=True, boundary=None, padded=False) + clean_power_spec=np.square(np.sum(hannWin)*np.abs(Zxx)) + f,t,Zxx=stft(processed_speech[0:int(num_frames)*skiprate+int(winlength-skiprate)], fs=fs, window=hannWin, nperseg=winlength, noverlap=winlength-skiprate, nfft=n_fft, detrend=False, return_onesided=True, boundary=None, padded=False) + enh_power_spec=np.square(np.sum(hannWin)*np.abs(Zxx)) + + bark_filt = barks(fs, n_fft, n_barks=32) + clean_power_spec_bark= np.dot(bark_filt,clean_power_spec) + enh_power_spec_bark= np.dot(bark_filt,enh_power_spec) + + clean_power_spec_bark_2=np.square(clean_power_spec_bark) + diff_power_spec_2 = np.square(clean_power_spec_bark-enh_power_spec_bark) + + bsd = np.mean(np.sum(diff_power_spec_2,axis=0)/np.sum(clean_power_spec_bark_2,axis=0)) + return bsd + + + + + diff --git a/TorchJaekwon/Evaluater/Package/pysepm/util.py b/TorchJaekwon/Evaluater/Package/pysepm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..608ba2012d779ddff298a018b33ebabc7babb9b3 --- /dev/null +++ b/TorchJaekwon/Evaluater/Package/pysepm/util.py @@ -0,0 +1,55 @@ +import numpy as np +from scipy.signal import firls,kaiser,upfirdn +from fractions import Fraction + +def extract_overlapped_windows(x,nperseg,noverlap,window=None): + # source: https://github.com/scipy/scipy/blob/v1.2.1/scipy/signal/spectral.py + step = nperseg - noverlap + shape = x.shape[:-1]+((x.shape[-1]-noverlap)//step, nperseg) + strides = x.strides[:-1]+(step*x.strides[-1], x.strides[-1]) + result = np.lib.stride_tricks.as_strided(x, shape=shape, + strides=strides) + if window is not None: + result = window * result + return result + +def resample_matlab_like(x_orig,p,q): + if len(x_orig.shape)>2: + raise ValueError('x must be a vector or 2d matrix') + + if x_orig.shape[0] Optional[str]: + root_path_list:list = [root_path] + root_path_list.append(root_path.replace("./",f'{TORCH_JAEKWON_PATH}/')) + + for root_path in root_path_list: + for root,dirs,files in os.walk(root_path): + if len(files) > 0: + for file in files: + if os.path.splitext(file)[0] == module_name: + if TORCH_JAEKWON_PATH in root: + torch_jaekwon_parent_path:str = '/'.join(TORCH_JAEKWON_PATH.split('/')[:-1]) + return f'{root}/{os.path.splitext(file)[0]}'.replace(torch_jaekwon_parent_path+'/','').replace("/",".") + else: + return 
f'{root}/{os.path.splitext(file)[0]}'.replace("./","").replace("/",".") + return None + + @staticmethod + def get_module_class(root_path:str,module_name:str): + module_path:str = GetModule.get_import_path_of_module(root_path,module_name) + module_from = importlib.import_module(module_path) + return getattr(module_from,module_name) + + @staticmethod + def get_model( + model_name:str, + root_path:str = './Model' + ) -> nn.Module: + module_file_path:str = GetModule.get_import_path_of_module(root_path, model_name) + file_module = importlib.import_module(module_file_path) + class_module = getattr(file_module,model_name) + argument_getter:Callable[[],dict] = getattr(class_module,'get_argument_of_this_model',lambda: dict()) + model_parameter:dict = argument_getter() + if len(model_parameter) == 0: + model_parameter = HParams().model.class_meta_dict.get(model_name,{}) + if not model_parameter: + model_parameter = getattr(HParams().model,model_name,dict()) + if not model_parameter: print(f'''GetModule: Model [{model_name}] doesn't have changed arguments''') + model:nn.Module = class_module(**model_parameter) + return model \ No newline at end of file diff --git a/TorchJaekwon/Inference/Inferencer/Inferencer.py b/TorchJaekwon/Inference/Inferencer/Inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..965752a3cc4b6a1051f79297a3e9416324dfa954 --- /dev/null +++ b/TorchJaekwon/Inference/Inferencer/Inferencer.py @@ -0,0 +1,140 @@ +#type +from typing import List, Tuple,Union, Literal +from torch import Tensor +#package +import os +import torch +import torch.nn as nn +from tqdm import tqdm +#torchjaekwon +from TorchJaekwon.GetModule import GetModule +from TorchJaekwon.Util.UtilData import UtilData +#internal + +class Inferencer(): + def __init__(self, + output_dir:str, + experiment_name:str, + model:Union[nn.Module,object], + model_class_name:str, + set_type:Literal[ 'single', 'dir', 'testset' ], + set_meta_dict: dict, + device:torch.device, + ) -> None: + self.output_dir:str = output_dir + self.experiment_name:str = experiment_name + + self.device:torch.device = device + + assert model_class_name is not None or model is not None, "model_class_name or model must be not None" + self.model:Union[nn.Module,object] = self.get_model(model_class_name) if model is None else model + self.shared_dir_name:str = '0shared0' + + self.set_type:Literal[ 'single', 'dir', 'testset' ] = set_type + self.set_meta_dict:dict = set_meta_dict + + ''' + ============================================================== + abstract method start + ============================================================== + ''' + + def get_inference_meta_data_list(self) -> List[dict]: + meta_data_list = list() + for data_name in self.h_params.data.data_config_per_dataset_dict: + meta:list = self.util_data.pickle_load(f'{self.h_params.data.root_path}/{data_name}_test.pkl') + meta_data_list += meta + return meta_data_list + + def get_output_dir_path(self, pretrained_name:str, meta_data:dict) -> Tuple[str,str]: + output_dir_path: str = f'''{self.output_dir}/{self.experiment_name}({pretrained_name})/{meta_data["test_name"]}''' + shared_output_dir_path:str = f'''{self.output_dir}/{self.shared_dir_name}/{meta_data["test_name"]}''' + return output_dir_path, shared_output_dir_path + + def read_data_dict_by_meta_data(self, meta_data:dict) -> dict: + ''' + { + "model_input": + "gt": { + "audio", + "spectrogram" + } + } + ''' + data_dict = dict() + data_dict["gt"] = dict() + data_dict["pred"] = dict() + + def post_process(self, 
data_dict: dict) -> dict: + return data_dict + + def save_data(self, output_dir_path:str, shared_output_dir_path:str, meta_data:dict, data_dict:dict) -> None: + pass + + @torch.no_grad() + def update_data_dict_by_model_inference(self, data_dict: dict) -> dict: + if type(data_dict["model_input"]) == Tensor: + data_dict["pred"] = self.model(data_dict["model_input"].to(self.device)) + return data_dict + ''' + ============================================================== + abstract method end + ============================================================== + ''' + + def get_model(self, model_class_name:str) -> nn.Module: + return GetModule.get_model(model_class_name) if (model_class_name not in [None,'']) else None + + def inference(self, + pretrained_root_dir:str, + pretrained_dir_name:str, + pretrain_module_name:str + ) -> None: + pretrained_path_list:List[str] = self.get_pretrained_path_list( + pretrain_root_dir= pretrained_root_dir, + pretrain_dir_name = pretrained_dir_name, + pretrain_module_name= pretrain_module_name + ) + + for pretrained_path in pretrained_path_list: + self.pretrained_load(pretrained_path) + pretrained_name:str = UtilData.get_file_name(file_path=pretrained_path) + meta_data_list:List[dict] = self.get_inference_meta_data_list() + for meta_data in tqdm(meta_data_list,desc='inference by meta data'): + output_dir_path, shared_output_dir_path = self.get_output_dir_path(pretrained_name=pretrained_name,meta_data=meta_data) + if output_dir_path is None: continue + data_dict:dict = self.read_data_dict_by_meta_data(meta_data=meta_data) + data_dict = self.update_data_dict_by_model_inference(data_dict) + data_dict:dict = self.post_process(data_dict) + + self.save_data(output_dir_path, shared_output_dir_path, meta_data, data_dict) + + def get_pretrained_path_list(self, + pretrain_root_dir:str, + pretrain_dir_name:str, + pretrain_module_name:str + ) -> List[str]: + pretrain_dir = f"{pretrain_root_dir}/{pretrain_dir_name}" + + if pretrain_module_name in ["all","last_epoch"]: + pretrain_name_list:List[str] = [ + pretrain_module + for pretrain_module in os.listdir(pretrain_dir) + if os.path.splitext(pretrain_module)[-1] in [".pth"] and "checkpoint" not in pretrain_module + ] + pretrain_name_list.sort() + + if pretrain_module_name == "last_epoch": + pretrain_name_list = [pretrain_name_list[-1]] + else: + pretrain_name_list:List[str] = [pretrain_module_name] + + return [f"{pretrain_dir}/{pretrain_name}" for pretrain_name in pretrain_name_list] + + def pretrained_load(self,pretrain_path:str) -> None: + if pretrain_path is None: + return + pretrained_load:dict = torch.load(pretrain_path,map_location='cpu') + self.model.load_state_dict(pretrained_load) + self.model = self.model.to(self.device) + self.model.eval() \ No newline at end of file diff --git a/TorchJaekwon/JupyterNotebookUtil.py b/TorchJaekwon/JupyterNotebookUtil.py new file mode 100644 index 0000000000000000000000000000000000000000..9013442be80bfeb413614e97eca5593c8a9aeb1d --- /dev/null +++ b/TorchJaekwon/JupyterNotebookUtil.py @@ -0,0 +1,201 @@ +from typing import List, Literal, Union, Tuple +try: import IPython.display as ipd +except: print('[error] there is no IPython package') +try: import pandas as pd +except: print('[error] there is no pandas package') + +import re +import librosa + +from TorchJaekwon.Util.UtilAudio import UtilAudio +from TorchJaekwon.Util.UtilAudioMelSpec import UtilAudioMelSpec +from TorchJaekwon.Util.UtilData import UtilData + +LOWER_IS_BETTER_SYMBOL = "↓" +HIGHER_IS_BETTER_SYMBOL = "↑" 
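# Minimal usage sketch for the Inferencer defined above, assuming a concrete subclass
# that implements read_data_dict_by_meta_data()/save_data(); the class name and paths
# below are hypothetical placeholders, not part of this patch:
#
#   inferencer = FlashSRInferencer(
#       output_dir='./Output', experiment_name='flashsr',
#       model=None, model_class_name='FlashSR',
#       set_type='testset', set_meta_dict=dict(),
#       device=torch.device('cuda'))
#   inferencer.inference(
#       pretrained_root_dir='./Train/Log',
#       pretrained_dir_name='flashsr',
#       pretrain_module_name='last_epoch')   # or 'all', or an explicit '*.pth' file name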
+PLUS_MINUS_SYMBOL = "±" + +class JupyterNotebookUtil(): + def __init__(self, + output_dir:str = None, + table_data_width:int = None, + audio_sr:int = 44100 + ) -> None: + self.indent:str = ' ' + self.media_idx_dict:dict = {'audio':0, 'img':0} + self.html_start_list:List[str] = [ + '', + '', + '', + '', + '', + '', + '', + '', + '
', + ] + self.html_end_list:List[str] = [ + '
', + '', + '', + ] + self.output_dir:str = output_dir + self.media_save_dir_name:str = 'media' + + mel_spec_config = UtilAudioMelSpec.get_default_mel_spec_config(audio_sr) + self.mel_spec_util = UtilAudioMelSpec(**mel_spec_config) + + def get_table_html_list(self, + dict_list: List[dict], + use_pandas:bool = False + ) -> List[str]: + ''' + Keys will be the table head items + Values will be the table body items + + dict_list = [ + {'name':'test_sample_name', 'model1': html_code/float/str, 'model2': html_code/float/str, ...}, + / + {'name':'model_name', 'metric1': html_code/float/str, 'metric2': html_code/float/str, ...}, + ... + ] + ''' + if use_pandas: + df = pd.DataFrame(dict_list) + return df.to_html(escape=False,index=False) + + html_list = list() + table_head_item_list = list(dict_list[0].keys()) + html_list.append('') + html_list.append('') + html_list.append('') + for table_head_item in table_head_item_list: + html_list.append(f'') + html_list.append('') + html_list.append('') + + html_list.append('') + for html_dict in dict_list: + html_list.append('') + for table_head_item in table_head_item_list: + html_list.append(f'''''') + html_list.append('') + html_list.append('') + html_list.append('
{table_head_item}
{html_dict.get(table_head_item,'')}
') + + return html_list + + def save_html(self, html_list:List[str], file_name:str = 'plot.html') -> None: + final_html_list:list = self.html_start_list + html_list + self.html_end_list + indent_depth:int = 0 + for idx in range(1, len(final_html_list)): + indent_depth += self.get_indent_depth_changed(final_html_list[idx - 1], final_html_list[idx]) + final_html_list[idx] = self.indent * indent_depth + final_html_list[idx] + UtilData.txt_save(f'{self.output_dir}/{file_name}', final_html_list) + + def get_html_text(self, + text:str, + tag:Literal['h1','h2','h3','h4','h5','h6','p'] = 'h1' + ) -> str: + return f'<{tag}>{text}' + + def get_html_img(self, + src_path:str = None, + width:int=150 + ) -> str: #html code + style:str = '' if width is None else f'style="width:{width}px"' + return f'''''' + + def get_media_path(self, type:Literal['audio','img']) -> str: + ext_dict = {'audio':'wav', 'img':'png'} + path_dict = dict() + path_dict['abs'] = f'{self.output_dir}/{self.media_save_dir_name}/{type}_{str(self.media_idx_dict[type]).zfill(5)}.{ext_dict[type]}' + path_dict['relative'] = f'''./{self.media_save_dir_name}{path_dict['abs'].split(self.media_save_dir_name)[-1]}''' + self.media_idx_dict[type] += 1 + return path_dict + + def get_html_audio(self, + audio_path:str = None, + cp_to_html_dir:bool = True, + sample_rate:int = None, + mel_spec_plot:bool = True, + spec_plot:bool = False, + width:int=200 + ) -> Union[str, Tuple[str,str]]: #audio_html_code, img_html_code + style:str = '' if width is None else f'style="width:{width}px"' + if cp_to_html_dir: + audio, sr = UtilAudio.read(audio_path = audio_path, sample_rate=sample_rate) + path_dict = self.get_media_path('audio') + UtilAudio.write(audio_path=path_dict['abs'], audio=audio, sample_rate=sr) + audio_path = path_dict['relative'] + + html_code_dict = dict() + html_code_dict['audio'] = f'''''' + if mel_spec_plot: + mel_spec = self.mel_spec_util.get_hifigan_mel_spec(audio) + if len(mel_spec.shape) == 3: mel_spec = mel_spec[0] + img_path = f'{self.output_dir}/{self.media_save_dir_name}/img_{str(self.media_idx_dict["img"]).zfill(5)}.png' + self.media_idx_dict["img"] += 1 + self.mel_spec_util.mel_spec_plot(save_path=img_path, mel_spec=mel_spec) + img_path = f'./{self.media_save_dir_name}{img_path.split(self.media_save_dir_name)[-1]}' + html_code_dict['mel'] = self.get_html_img(img_path, width) + + if spec_plot: + stft_mag = self.mel_spec_util.stft_torch(audio)["mag"].squeeze() + stft_db = librosa.amplitude_to_db(stft_mag) + path_dict = self.get_media_path('img') + self.mel_spec_util.mel_spec_plot(save_path=path_dict['abs'], mel_spec=stft_db) + html_code_dict['spec'] = self.get_html_img(path_dict['relative'], width) + + return html_code_dict + + def get_html_tag_list(self, html_str:str) -> List[str]: + html_tag_list = re.findall(r']+>', html_str) + for idx in range(len(html_tag_list)): + html_str_split = html_tag_list[idx].split(' ') + if len(html_str_split) > 1: + html_tag_list[idx] = html_str_split[0] + html_str_split[-1] + return html_tag_list + + def get_indent_depth_changed(self, prev_str:str, current_str:str) -> bool: + prev_tag_list = self.get_html_tag_list(prev_str) + current_tag_list = self.get_html_tag_list(current_str) + if len(current_tag_list) == 0 or len(prev_tag_list) == 0: + return 0 + + if prev_tag_list[0] == current_tag_list[0]: + return 0 + + if ' None: + for html_result in html_list: + ipd.display(ipd.HTML(html_result)) + \ No newline at end of file diff --git a/TorchJaekwon/Model/Activation/ActivationFunction.py 
b/TorchJaekwon/Model/Activation/ActivationFunction.py new file mode 100644 index 0000000000000000000000000000000000000000..bca29dadea17b3802638c701fe0a21b43a72821c --- /dev/null +++ b/TorchJaekwon/Model/Activation/ActivationFunction.py @@ -0,0 +1,13 @@ +import numpy as np +import torch + +class ActivationFunction: + @staticmethod + def exponential_sigmoid(x: torch.Tensor) -> torch.Tensor: + """Exponential sigmoid. + Args: + x: [torch.float32; [...]], input tensors. + Returns: + sigmoid outputs. + """ + return 2.0 * torch.sigmoid(x) ** np.log(10) + 1e-7 \ No newline at end of file diff --git a/TorchJaekwon/Model/Activation/AliasFree1d.py b/TorchJaekwon/Model/Activation/AliasFree1d.py new file mode 100644 index 0000000000000000000000000000000000000000..c73c13c239d342389750b9fc89d6ef3d8080e2b1 --- /dev/null +++ b/TorchJaekwon/Model/Activation/AliasFree1d.py @@ -0,0 +1,30 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. +# Used in BigVGAN + +import torch.nn as nn +from TorchJaekwon.Model.AudioModule.Resample.UpSample1d import UpSample1d +from TorchJaekwon.Model.AudioModule.Resample.DownSample1d import DownSample1d + + +class AliasFree1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/TorchJaekwon/Model/Activation/Snake.py b/TorchJaekwon/Model/Activation/Snake.py new file mode 100644 index 0000000000000000000000000000000000000000..02474fbba8f7b9e1d36c1c71799148ffae7fac0e --- /dev/null +++ b/TorchJaekwon/Model/Activation/Snake.py @@ -0,0 +1,42 @@ +import torch +from torch import nn +from torch.nn import Parameter + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. 
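        The activation itself is Snake(x) = x + (1/alpha) * sin^2(alpha * x);
        when alpha_logscale is True, exp(alpha) is used in place of alpha, and
        no_div_by_zero guards the division (cf. SnakeBeta.forward below).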
+ ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 \ No newline at end of file diff --git a/TorchJaekwon/Model/Activation/SnakeBeta.py b/TorchJaekwon/Model/Activation/SnakeBeta.py new file mode 100644 index 0000000000000000000000000000000000000000..acabad2ed4fa6920587ecf553b4ae845293d4c27 --- /dev/null +++ b/TorchJaekwon/Model/Activation/SnakeBeta.py @@ -0,0 +1,63 @@ +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. 
+ SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x \ No newline at end of file diff --git a/TorchJaekwon/Model/AudioModule/ConvReverb.py b/TorchJaekwon/Model/AudioModule/ConvReverb.py new file mode 100644 index 0000000000000000000000000000000000000000..f1175069c3206c4fb64c1e48c8c2803aff049980 --- /dev/null +++ b/TorchJaekwon/Model/AudioModule/ConvReverb.py @@ -0,0 +1,58 @@ +from torch import Tensor +from numpy import ndarray + +import librosa +import soundfile as sf +import torch +import torch.nn as nn + +class ConvReverb(nn.Module): + def __init__(self): + super(ConvReverb,self).__init__() + + def conv_reverb_by_one_ir( + self, + input_signal:Tensor, #(batch,sampletime_length) + input_ir:Tensor #(batch,sampletime_length) + ) -> Tensor: + + zero_padded_input_signal = nn.functional.pad(input_signal, (0, input_ir.shape[-1] - 1)) + input_signal_fft = torch.fft.rfft(zero_padded_input_signal, dim=1) #torch.rfft(zero_padded_input_signal, 1) + + zero_pad_final_fir = nn.functional.pad(input_ir, (0, input_signal.shape[-1] - 1)) + + fir_fft = torch.fft.rfft(zero_pad_final_fir, dim=1) #torch.rfft(zero_pad_final_fir, 1) + output_signal_fft:Tensor = fir_fft * input_signal_fft + + output_signal = torch.fft.irfft(output_signal_fft, dim=1) #torch.irfft(output_signal_fft, 1) + + return output_signal + + def forward( + self, + input_signal:Tensor, #(batch,sampletime_length) + input_ir:Tensor #(batch,sampletime_length) + ) -> Tensor: + assert ((len(input_signal.shape) == 2) or (len(input_signal.shape) == 3)), "input shape is wrong" + if len(input_signal.shape) == 2: + return self.conv_reverb_by_one_ir(input_signal,input_ir) + else: + left_reverb_audio:Tensor = self.conv_reverb_by_one_ir(input_signal[:,0,:],input_ir[:,0,:]).unsqueeze(1) + right_reverb_audio:Tensor = self.conv_reverb_by_one_ir(input_signal[:,1,:],input_ir[:,1,:]).unsqueeze(1) + reverb_audio:Tensor = torch.cat([left_reverb_audio,right_reverb_audio],axis=1) + return reverb_audio + +if __name__ == "__main__": + vocal_audio_dir:str = "/home/jakeoneijk/220101_data/MusDBMainVocal/train/A Classic Education - NightOwl/A Classic Education - NightOwl_Main Vocal.wav" + ir_dir:str = "/home/jakeoneijk/220101_data/DetmoldSRIRStereo/SetB_LSandWFSOrchestra/Data/OpenArray/wfs_R1/S1.wav" + + vocal_audio,sr = librosa.load(vocal_audio_dir,sr=None,mono=False) + ir_audio,sr = librosa.load(ir_dir,sr=sr,mono=False) + + vocal_tensor:Tensor = torch.from_numpy(vocal_audio).unsqueeze(0) + ir_tensor:Tensor = torch.from_numpy(ir_audio).unsqueeze(0) + + conv_reverb = ConvReverb() + reverberated_audio:Tensor = conv_reverb(vocal_tensor,ir_tensor) + reverberated_audio_numpy:ndarray = reverberated_audio.squeeze().numpy() + sf.write("./reverb_audio.wav", data=reverberated_audio_numpy.T, samplerate=sr) \ No newline at end of file diff --git a/TorchJaekwon/Model/AudioModule/FeatureExtract/ConstantQTransform.py b/TorchJaekwon/Model/AudioModule/FeatureExtract/ConstantQTransform.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5c4308fb9014af7af40d4450165e0de0522ae3 --- /dev/null +++ b/TorchJaekwon/Model/AudioModule/FeatureExtract/ConstantQTransform.py @@ -0,0 +1,30 @@ +#type +from torch import Tensor + +import torch.nn as nn + +from TorchJaekwon.Model.AudioModule.Package.nnAudio.cqt 
import CQT2010v2 + +class ConstantQTransform(nn.Module): + def __init__(self, + sample_rate:int = 22050, + hop_length:int = 512, + fmin:float = 32.7, #C1 ~= 32.70 Hz, fmax = 2 ** (number_of_freq_bins / bins_per_octave) * fmin + number_of_freq_bins:int = 84, #starting at fmin + bins_per_octave:int = 12 + ) -> None: + super().__init__() + self.cqt = CQT2010v2( + sr = sample_rate, + hop_length = hop_length, + fmin = fmin, + n_bins = number_of_freq_bins, + bins_per_octave=bins_per_octave, + trainable=False, + output_format='Magnitude') + + def forward(self, + inputs:Tensor #[torch.float32; [B, T]], input speech signal. + ) -> Tensor: #[torch.float32; [B, bins, T / strides]], CQT magnitudes. + + return self.cqt(inputs[:, None]) \ No newline at end of file diff --git a/TorchJaekwon/Model/AudioModule/FeatureExtract/MelSpectrogram.py b/TorchJaekwon/Model/AudioModule/FeatureExtract/MelSpectrogram.py new file mode 100644 index 0000000000000000000000000000000000000000..97644ee9fce539f0ed88dbcf3d3e70822522ac98 --- /dev/null +++ b/TorchJaekwon/Model/AudioModule/FeatureExtract/MelSpectrogram.py @@ -0,0 +1,64 @@ +from torch import Tensor + +import torch +import torch.nn as nn +from librosa.filters import mel as librosa_mel_fn + +class MelSpectrogram(nn.Module): + def __init__(self, + sample_rate:int, + nfft: int, + hop_size: int, + mel_size:int, + frequency_min:float, + frequency_max:float) -> None: + super().__init__() + # [mel_size, nfft // 2 + 1] + self.nfft,self.hop_size,self.mel_size = nfft,hop_size,mel_size + self.register_buffer( + 'mel_filterbank', + torch.from_numpy(librosa_mel_fn(sr = sample_rate, + n_fft = nfft, + n_mels = mel_size, + fmin = frequency_min, + fmax = frequency_max)).float(), + persistent=False) + self.register_buffer( + 'hann_window', torch.hann_window(nfft), persistent=False) + + def forward(self, + audio: Tensor #[torch.float32; [B, T]], audio signal, [-1, 1]-ranged. + ) -> Tensor: #[torch.float32; [B, mel, T / strides]], mel spectrogram + if torch.min(audio) < -1.: + print('min value is ', torch.min(audio)) + if torch.max(audio) > 1.: + print('max value is ', torch.max(audio)) + + # [BatchSize, nfft // 2 + 1, T / hop_size] + spec:Tensor = torch.stft( audio, + n_fft=self.nfft, + hop_length=self.hop_size, + window=self.hann_window, + center=True, pad_mode='reflect', + return_complex=True) + # [BatchSize, nfft // 2 + 1, T / hop_size] + mag:Tensor = abs(spec) + # [BatchSize, mel_size, T / hop_size] + return torch.matmul(self.mel_filterbank, mag) + + def get_log_mel(self, + audio: Tensor #[torch.float32; [B, T]], audio signal, [-1, 1]-ranged. + ) -> Tensor: #[torch.float32; [B, mel, T / strides]], mel spectrogram + mel_spec:Tensor = self.forward(audio) + return torch.log(mel_spec + 1e-7) + + def get_dynamic_range_compresed_mel( + self, + audio: Tensor #[torch.float32; [B, T]], audio signal, [-1, 1]-ranged. 
+ ) -> Tensor: #[torch.float32; [B, mel, T / strides]], mel spectrogram + #used in hi-fi gan + mel_spec:Tensor = self.forward(audio) + return self.dynamic_range_compression(mel_spec) + + def dynamic_range_compression(self, x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) diff --git a/TorchJaekwon/Model/AudioModule/Filter/Filter.py b/TorchJaekwon/Model/AudioModule/Filter/Filter.py new file mode 100644 index 0000000000000000000000000000000000000000..859a91b2e77a4209c5e59639376907be051cf167 --- /dev/null +++ b/TorchJaekwon/Model/AudioModule/Filter/Filter.py @@ -0,0 +1,50 @@ +from typing import Callable + +import torch +import math + +class Filter: + @staticmethod + def sinc(x: torch.Tensor): + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + @staticmethod + def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + sinc:Callable = torch.sinc if 'sinc' in dir(torch) else Filter.sinc + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter \ No newline at end of file diff --git a/TorchJaekwon/Model/AudioModule/Filter/LowPassFilter1d.py b/TorchJaekwon/Model/AudioModule/Filter/LowPassFilter1d.py new file mode 100644 index 0000000000000000000000000000000000000000..548021cc1256d568cc0d8c84f5c9475e91b4dd91 --- /dev/null +++ b/TorchJaekwon/Model/AudioModule/Filter/LowPassFilter1d.py @@ -0,0 +1,40 @@ +import torch.nn as nn +import torch.nn.functional as F +from TorchJaekwon.Model.AudioModule.Filter.Filter import Filter + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. 
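        # Units assumed here (inferred from how the resamplers in this patch call it):
        # cutoff and half_width are normalized to the sampling rate, so 0.5 is Nyquist,
        # and DownSample1d/UpSample1d pass cutoff = 0.5 / ratio, half_width = 0.6 / ratio.
        # Filter.kaiser_sinc_filter1d above derives the Kaiser window beta from the
        # estimated stopband attenuation A with the standard design rule
        # (beta = 0.1102 * (A - 8.7) for A > 50).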
+ super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = Filter.kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + + return out \ No newline at end of file diff --git a/TorchJaekwon/Model/AudioModule/Filter/nkandpa2_music_enhancement/MicrophoneEQ.py b/TorchJaekwon/Model/AudioModule/Filter/nkandpa2_music_enhancement/MicrophoneEQ.py new file mode 100644 index 0000000000000000000000000000000000000000..81b91319c3c5025e91eb1ab99dec00e628ff0609 --- /dev/null +++ b/TorchJaekwon/Model/AudioModule/Filter/nkandpa2_music_enhancement/MicrophoneEQ.py @@ -0,0 +1,61 @@ +from torch import Tensor + +import torch +import torch.nn as nn +import numpy as np +from scipy import signal + +from TorchJaekwon.Model.AudioModule.Filter.nkandpa2_music_enhancement.Util import Util + +class MicrophoneEQ(nn.Module): + """ + Apply a random EQ on bands demarcated by `bands` + """ + def __init__(self, low_db=-15, hi_db=15, bands=[200, 1000, 4000], filter_length=8192, rate=16000): + super(MicrophoneEQ, self).__init__() + self.low_db = low_db + self.hi_db = hi_db + self.rate = rate + self.filter_length = filter_length + self.firs = nn.Parameter(self.create_filters(bands)) + + def create_filters(self, bands): + """ + Generate bank of FIR bandpass filters with band cutoffs specified by `bands` + """ + ir = np.zeros([self.filter_length]) + ir[0] = 1 + bands = [35] + bands + fir = np.zeros([len(bands) + 1, self.filter_length]) + for j in range(len(bands)): + freq = bands[j] / (self.rate/2) + bl, al = signal.butter(4, freq, btype='low') + bh, ah = signal.butter(4, freq, btype='high') + fir[j] = signal.lfilter(bl, al, ir) + ir = signal.lfilter(bh, ah, ir) + fir[-1] = ir + pfir = np.square(np.abs(np.fft.fft(fir,axis=1))) + pfir = np.real(np.fft.ifft(pfir, axis=1)) + fir = np.concatenate((pfir[:,self.filter_length//2:self.filter_length], pfir[:,0:self.filter_length//2]), axis=1) + return torch.tensor(fir, dtype=torch.float32) + + def get_eq_filter(self, band_gains): + """ + Apply `band_gains` to bank of FIR bandpass filters to get the final EQ filter + """ + band_gains = 10**(band_gains/20) + eq_filter = (band_gains[:,:,None] * self.firs[None,:,:]).sum(dim=1, keepdim=True) + return eq_filter + + def forward(self, + x, # [batch, signal, length] + gain=None): + gains = self.get_random_gain(batch_size=x.shape[0]) if gain is None else gain + gains = torch.cat((torch.zeros((x.shape[0], 1), device=self.firs.device), gains), dim=1) + eq_filter = self.get_eq_filter(gains) + eq_x = Util.batch_convolution(x, eq_filter, pad_both_sides=True) + return eq_x + + def get_random_gain(self, batch_size:int = 1) -> Tensor: + return (self.hi_db - self.low_db)*torch.rand(batch_size, self.firs.shape[0]-1, device=self.firs.device) + self.low_db + \ No newline at end of file diff --git a/TorchJaekwon/Model/AudioModule/Filter/nkandpa2_music_enhancement/README.md 
b/TorchJaekwon/Model/AudioModule/Filter/nkandpa2_music_enhancement/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5c5081315f05e2c6b166009e8aceefc373a2dc93 --- /dev/null +++ b/TorchJaekwon/Model/AudioModule/Filter/nkandpa2_music_enhancement/README.md @@ -0,0 +1,2 @@ +This code is from [nkandpa2/music_enhancement](https://github.com/nkandpa2/music_enhancement) +[Music Enhancement via Image Translation and Vocoding](https://arxiv.org/abs/2204.13289) \ No newline at end of file diff --git a/TorchJaekwon/Model/AudioModule/Filter/nkandpa2_music_enhancement/Util.py b/TorchJaekwon/Model/AudioModule/Filter/nkandpa2_music_enhancement/Util.py new file mode 100644 index 0000000000000000000000000000000000000000..26d89621ee8f7475ac66c57f17159a46a182b4b8 --- /dev/null +++ b/TorchJaekwon/Model/AudioModule/Filter/nkandpa2_music_enhancement/Util.py @@ -0,0 +1,137 @@ +from torch import Tensor + +import torch +import torch.nn.functional as F + +class Util: + @staticmethod + def batch_convolution(x, f, pad_both_sides=True): + """ + Do batch-elementwise convolution between a batch of signals `x` and batch of filters `f` + x: (batch_size x channels x signal_length) size tensor + f: (batch_size x channels x filter_length) size tensor + pad_both_sides: Whether to zero-pad x on left and right or only on left (Default: True) + """ + batch_size = x.shape[0] + f = torch.flip(f, (2,)) + if pad_both_sides: + x = F.pad(x, (f.shape[2]//2, f.shape[2]-f.shape[2]//2-1)) + else: + x = F.pad(x, (f.shape[2]-1, 0)) + #TODO: This assumes single-channel audio, fine for now + return F.conv1d(x.view(1, batch_size, -1), f, groups=batch_size).view(batch_size, 1, -1) + + @staticmethod + def augment(sample, rir=None, noise=None, eq_model=None, low_cut_model=None, rate=16000, nsr_range=[-30,-5], normalize=True, eps=1e-6): + sample = Util.perturb_silence(sample, eps=eps) + clean_sample = torch.clone(sample) + if not noise is None: + nsr_target = ((nsr_range[1] - nsr_range[0])*torch.rand(noise.shape[0]) + nsr_range[0]).to(noise) + sample = Util.apply_noise(sample, noise, nsr_target) + if not rir is None: + sample = Util.apply_reverb(sample, rir, None, rate=rate) + if not eq_model is None: + sample = eq_model(sample) + if not low_cut_model is None: + sample = low_cut_model(sample) + if normalize: + sample = 0.95*sample/sample.abs().max(dim=2, keepdim=True)[0] + + return clean_sample, sample + + @staticmethod + def perturb_silence(sample, eps=1e-6): + """ + Some samples have periods of silence which can cause numerical issues when taking log-spectrograms. 
Add a little noise + """ + return sample + eps*torch.randn_like(sample) + + @staticmethod + def apply_reverb(sample, rir, drr_target, rate=16000): + """ + Convolve batch of samples with batch of room impulse responses scaled to achieve a target direct-to-reverberation ratio + """ + if not drr_target is None: + direct_ir, reverb_ir = Util.decompose_rir(rir, rate=rate) + drr_db = Util.drr(direct_ir, reverb_ir) + scale = 10**((drr_db - drr_target)/20) + reverb_ir_scaled = scale[:, None, None]*reverb_ir + rir_scaled = torch.cat((direct_ir, reverb_ir_scaled), axis=2) + else: + rir_scaled = rir + return Util.batch_convolution(sample, rir_scaled, pad_both_sides=False) + + @staticmethod + def apply_noise(sample, noise, nsr_target, peak=False): + """ + Apply additive noise scaled to achieve target noise-to-signal ratio + """ + if peak: + nsr_curr = Util.pnsr(sample, noise) + noise_flat = noise.view(noise.shape[0], -1) + peak_noise = noise_flat.max(dim=1)[0] - noise_flat.min(dim=1)[0] + scale = 10**((nsr_target - nsr_curr)/20) + else: + nsr_curr = Util.nsr(sample, noise) + scale = torch.sqrt(10**((nsr_target - nsr_curr)/10)) + + return sample + scale[:, None, None]*noise + + @staticmethod + def apply_noise_wrt_snr(sample:Tensor, # [channel, signal_length] + noise:Tensor, # [channel, signal_length] + snr_target:float): + """ + Apply additive noise scaled to achieve target noise-to-signal ratio + """ + snr_curr = Util.nsr(noise, sample) + scale = 1/torch.sqrt(10**((snr_target - snr_curr)/10)) + + return sample + scale * noise + + @staticmethod + def nsr(sample, noise): + """ + Compute noise-to-signal ratio + """ + sample, noise = sample.view(sample.shape[0], -1), noise.view(noise.shape[0], -1) + signal_power = torch.square(sample).mean(dim=1) + noise_power = torch.square(noise).mean(dim=1) + return 10*torch.log10(noise_power/signal_power) + + @staticmethod + def pnsr(sample, noise): + """ + Compute peak noise-to-signal-ratio + """ + sample, noise = sample.view(sample.shape[0], -1), noise.view(noise.shape[0], -1) + peak_noise = noise.max(dim=1)[0] - noise.min(dim=1)[0] + signal = torch.square(sample).mean(dim=1) + return 20*torch.log10(peak_noise) - 10*torch.log10(signal) + + @staticmethod + def drr(direct_ir, reverb_ir): + """ + Compute direct-to-reverberation ratio + """ + direct_ir_flat = direct_ir.view(direct_ir.shape[0], -1) + reverb_ir_flat = reverb_ir.view(reverb_ir.shape[0], -1) + drr_db = 10*torch.log10(torch.square(direct_ir_flat).sum(dim=1)/torch.square(reverb_ir_flat).sum(dim=1)) + return drr_db + + @staticmethod + def decompose_rir(rir, rate=16000, window_ms=5): + direct_window = int(window_ms/1000*rate) + direct_ir, reverb_ir = rir[:,:,:direct_window], rir[:,:,direct_window:] + return direct_ir, reverb_ir + + @staticmethod + def preprocess_rir_wrt_window(rir, #[channel, signal_length] + rate=16000, + window_ms=2.5): + direct_impulse_index = rir.argmax().item() + window_len = int(window_ms/1000*rate) + if direct_impulse_index < window_len: + rir = torch.cat((torch.zeros(1, window_len - direct_impulse_index), rir), dim=1) + rir = rir[:, direct_impulse_index - window_len:] + return rir \ No newline at end of file diff --git a/TorchJaekwon/Model/AudioModule/Resample/DownSample1d.py b/TorchJaekwon/Model/AudioModule/Resample/DownSample1d.py new file mode 100644 index 0000000000000000000000000000000000000000..a3b926d963edae82e94adea50ec60c0339db3bc1 --- /dev/null +++ b/TorchJaekwon/Model/AudioModule/Resample/DownSample1d.py @@ -0,0 +1,16 @@ +import torch.nn as nn +from 
TorchJaekwon.Model.AudioModule.Filter.LowPassFilter1d import LowPassFilter1d +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx \ No newline at end of file diff --git a/TorchJaekwon/Model/AudioModule/Resample/UpSample1d.py b/TorchJaekwon/Model/AudioModule/Resample/UpSample1d.py new file mode 100644 index 0000000000000000000000000000000000000000..a29a448061e823461b376cc45c24255baf54a1ef --- /dev/null +++ b/TorchJaekwon/Model/AudioModule/Resample/UpSample1d.py @@ -0,0 +1,28 @@ +import torch.nn as nn +from torch.nn import functional as F +from TorchJaekwon.Model.AudioModule.Filter.Filter import Filter + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = Filter.kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x \ No newline at end of file diff --git a/TorchJaekwon/Model/ConditionalLayerNormalization.py b/TorchJaekwon/Model/ConditionalLayerNormalization.py new file mode 100644 index 0000000000000000000000000000000000000000..fac900df44c2f78c5f9b4935038d34679427d553 --- /dev/null +++ b/TorchJaekwon/Model/ConditionalLayerNormalization.py @@ -0,0 +1,20 @@ +from torch import Tensor + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ConditionalLayerNormalization(nn.Module): + def __init__(self, + input_channels:int, + style_condition_channels:int) -> None: + super().__init__() + self.gain_bias_conv1d = nn.Conv1d(style_condition_channels, input_channels * 2, 1) + + def forward(self, + input:Tensor, #[batch,input_channels,N] + style_condition:Tensor #[batch,style_condition_channels,N] + ): + normalized_input:Tensor = F.layer_norm(input,normalized_shape=input.shape[1:]) + weight, bias = self.gain_bias_conv1d(style_condition).chunk(2, dim=1) + return normalized_input * weight + bias \ No newline at end of file diff --git a/TorchJaekwon/Model/Diffusion/DDPM/BetaSchedule.py b/TorchJaekwon/Model/Diffusion/DDPM/BetaSchedule.py new file mode 100644 index 0000000000000000000000000000000000000000..189af0b9ca7bfa42d50296c333b976fb014fedc9 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/DDPM/BetaSchedule.py @@ -0,0 +1,27 @@ +from numpy import ndarray +from torch import Tensor + +import numpy as np + +class BetaSchedule: + + @staticmethod + def linear(timesteps:int, start:float=1e-4, end:float =2e-2) -> ndarray: + """ + linear schedule + """ + return np.linspace(start, end, timesteps) + + + @staticmethod + def cosine(timesteps:int, s:float=0.008) -> ndarray: + """ + cosine schedule + as 
proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps:int = timesteps + 1 + x = np.linspace(0, steps, steps) + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas:ndarray = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return np.clip(betas, a_min=0, a_max=0.999) diff --git a/TorchJaekwon/Model/Diffusion/DDPM/DDPM.py b/TorchJaekwon/Model/Diffusion/DDPM/DDPM.py new file mode 100644 index 0000000000000000000000000000000000000000..b989106c57f909ba0ad35f2925f0cc973ddc6846 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/DDPM/DDPM.py @@ -0,0 +1,294 @@ +from typing import Union, Callable, Literal, Optional, Tuple +from numpy import ndarray +from torch import Tensor, device + +from tqdm import tqdm +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + +from TorchJaekwon.GetModule import GetModule +from TorchJaekwon.Util.UtilData import UtilData +from TorchJaekwon.Util.UtilTorch import UtilTorch +from TorchJaekwon.Model.Diffusion.DDPM.DiffusionUtil import DiffusionUtil +from TorchJaekwon.Model.Diffusion.DDPM.BetaSchedule import BetaSchedule + +class DDPM(nn.Module): + def __init__(self, + model_class_name:Optional[str] = None, + model:Optional[nn.Module] = None, + + model_output_type:Literal['noise', 'x_start', 'v_prediction'] = 'noise', + timesteps:int = 1000, + + loss_func:Union[nn.Module, Callable, Tuple[str,str]] = F.mse_loss, # if tuple (package name, func name). ex) (torch.nn.functional, mse_loss) + + betas: Optional[ndarray] = None, + beta_schedule_type:Literal['linear','cosine'] = 'cosine', + beta_arg_dict:dict = dict(), + + unconditional_prob:float = 0, #if unconditional_prob > 0, this model works as classifier free guidance + cfg_scale:Optional[float] = None # classifer free guidance scale + ) -> None: + super().__init__() + if model_class_name is not None: + self.model = GetModule.get_model(model_name = model_class_name) + else: + self.model:nn.Module = model + self.model_output_type:Literal['noise', 'x_start', 'v_prediction'] = model_output_type + + self.loss_func:Union[nn.Module, Callable] = loss_func + + self.timesteps:int = timesteps + self.set_noise_schedule(betas=betas, beta_schedule_type=beta_schedule_type, beta_arg_dict=beta_arg_dict, timesteps=timesteps) + + self.unconditional_prob:float = unconditional_prob + self.cfg_scale:Optional[float] = cfg_scale + + def set_noise_schedule(self, + betas: Optional[ndarray] = None, + beta_schedule_type:Literal['linear','cosine'] = 'linear', + beta_arg_dict:dict = dict(), + timesteps:int = 1000, + ) -> None: + if betas is None: + beta_arg_dict.update({'timesteps':timesteps}) + betas = getattr(BetaSchedule,beta_schedule_type)(**beta_arg_dict) + + alphas:ndarray = 1. 
- betas + alphas_cumprod:ndarray = np.cumprod(alphas, axis=0) + alphas_cumprod_prev:ndarray = np.append(1., alphas_cumprod[:-1]) + + self.betas:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'betas', value = betas) + self.alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'alphas_cumprod', value = alphas_cumprod) + self.alphas_cumprod_prev:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'alphas_cumprod_prev', value = alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_alphas_cumprod', value = np.sqrt(alphas_cumprod)) + self.sqrt_one_minus_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_one_minus_alphas_cumprod', value = np.sqrt(1. - alphas_cumprod)) + self.log_one_minus_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'log_one_minus_alphas_cumprod', value = np.log(1. - alphas_cumprod)) + self.sqrt_recip_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_recip_alphas_cumprod', value = np.sqrt(1. / alphas_cumprod)) + self.sqrt_recipm1_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_recipm1_alphas_cumprod', value = np.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_variance', value = posterior_variance) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_log_variance_clipped', value = np.log(np.maximum(posterior_variance, 1e-20))) + self.posterior_mean_coef1:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_mean_coef1', value = betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + self.posterior_mean_coef2:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_mean_coef2', value = (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)) + + def forward(self, + x_start:Optional[Tensor] = None, + x_shape:Optional[tuple] = None, + cond:Optional[Union[dict,Tensor]] = None, + is_cond_unpack:bool = False, + stage: Literal['train', 'infer'] = 'train' + ) -> Tensor: # return loss value or sample + ''' + train diffusion model. 
+ return diffusion loss + ''' + x_start, cond, additional_data_dict = self.preprocess(x_start, cond) + if stage == 'train' and x_start is not None: + if x_shape is None: x_shape = x_start.shape + batch_size:int = x_shape[0] + input_device:device = x_start.device + t:Tensor = torch.randint(0, self.timesteps, (batch_size,), device=input_device).long() + if DDPM.make_decision(self.unconditional_prob): + cond:Optional[Union[dict,Tensor]] = self.get_unconditional_condition(cond=cond, condition_device=input_device) + return self.p_losses(x_start, cond, is_cond_unpack, t) + else: + return self.infer(x_shape = x_shape, cond = cond, is_cond_unpack = is_cond_unpack, additional_data_dict = additional_data_dict) + + def p_losses(self, + x_start:Tensor, + cond:Optional[Union[dict,Tensor]], + is_cond_unpack:bool, + t:Tensor, + noise:Optional[Tensor] = None): + noise:Tensor = UtilData.default(noise, lambda: torch.randn_like(x_start)) + x_noisy:Tensor = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output:Tensor = self.apply_model(x_noisy, t, cond, is_cond_unpack) + + if self.model_output_type == 'x_start': + target:Tensor = x_start + elif self.model_output_type == 'noise': + target:Tensor = noise + elif self.model_output_type == 'v_prediction': + target:Tensor = self.get_v(x_start, noise, t) + else: + print(f'''model output type is {self.model_output_type}. It should be in [x_start, noise]''') + raise NotImplementedError() + if target.shape != model_output.shape: print(f'warning: target shape({target.shape}) and model shape({model_output.shape}) are different') + return self.loss_func(target, model_output) + + def get_v(self, x, noise, t): + ''' + Progressive Distillation for Fast Sampling of Diffusion Models + https://arxiv.org/abs/2202.00512 + ''' + return ( + DiffusionUtil.extract(self.sqrt_alphas_cumprod, t, x.shape) * noise + - DiffusionUtil.extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def q_sample(self, x_start:Tensor, t:Tensor, noise=None) -> Tensor: + ''' + noisy x sample for forward process + ''' + noise = UtilData.default(noise, lambda: torch.randn_like(x_start)) + return ( + DiffusionUtil.extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + DiffusionUtil.extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
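        In closed form, using the buffers registered in set_noise_schedule:
            q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
        which is exactly what the three extract() calls below read out per timestep.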
+ """ + mean = DiffusionUtil.extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = DiffusionUtil.extract(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = DiffusionUtil.extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + @torch.no_grad() + def infer(self, + x_shape:tuple, + cond:Optional[Union[dict,Tensor]], + is_cond_unpack:bool, + additional_data_dict:dict): + if x_shape is None: x_shape = self.get_x_shape(cond) + model_device:device = UtilTorch.get_model_device(self.model) + x:Tensor = torch.randn(x_shape, device = model_device) + for i in tqdm(reversed(range(0, self.timesteps)), desc='sample time step', total=self.timesteps): + x = self.p_sample(x = x, t = torch.full((x_shape[0],), i, device= model_device, dtype=torch.long), cond = cond, is_cond_unpack = is_cond_unpack) + + return self.postprocess(x, additional_data_dict = additional_data_dict) + + @torch.no_grad() + def p_sample(self, + x:Tensor, + t:Tensor, + cond:Optional[Union[dict,Tensor]], + is_cond_unpack:bool, + clip_denoised:bool = False, # dangerous if True + repeat_noise:bool = False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, cond = cond, is_cond_unpack = is_cond_unpack, clip_denoised = clip_denoised) + noise = DiffusionUtil.noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + def p_mean_variance(self, + x:Tensor, + t:Tensor, + cond:Optional[Union[dict,Tensor]], + is_cond_unpack:bool, + clip_denoised: bool) -> Tuple[Tensor]: + + model_output:Tensor = self.apply_model(x, t, cond, is_cond_unpack, cfg_scale=self.cfg_scale) + if self.model_output_type == "noise": + x_recon = self.predict_x_start_from_noise(x, t=t, noise=model_output) + elif self.model_output_type == 'x_start': + x_recon = model_output + elif self.model_output_type == 'v_prediction': + x_recon = self.predict_x_start_from_v(x, t=t, v=model_output) + + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + def predict_x_start_from_noise(self, x_t, t, noise): + return ( + DiffusionUtil.extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + DiffusionUtil.extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_x_start_from_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. 
- alphas_cumprod))) + return ( + DiffusionUtil.extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - DiffusionUtil.extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_noise_from_v(self, x_t, t, v): + return ( + DiffusionUtil.extract(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + DiffusionUtil.extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) + * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + DiffusionUtil.extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + DiffusionUtil.extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = DiffusionUtil.extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = DiffusionUtil.extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def preprocess(self, x_start:Tensor, cond:Optional[Union[dict,Tensor]] = None) -> Tuple[Tensor, Optional[Union[dict,Tensor]], dict]: + return x_start, cond, None + + def postprocess(self, x:Tensor, additional_data_dict:dict) -> Tensor: + return x + + def apply_model(self, + x:Tensor, + t:Tensor, + cond:Optional[Union[dict,Tensor]], + is_cond_unpack:bool, + cfg_scale:Optional[float] = None + ) -> Tensor: + if cfg_scale is None or cfg_scale == 1.0: + if cond is None: + return self.model(x, t) + elif is_cond_unpack: + return self.model(x, t, **cond) + else: + return self.model(x, t, cond) + else: + model_conditioned_output = self.model(x, t, **cond) if is_cond_unpack else self.model(x, t, cond) + unconditional_conditioning = self.get_unconditional_condition(cond=cond) + model_unconditioned_output = self.model(x, t, **unconditional_conditioning) if is_cond_unpack else self.model(x, t, unconditional_conditioning) + return model_unconditioned_output + cfg_scale * (model_conditioned_output - model_unconditioned_output) + + @staticmethod + def make_decision(probability:float #[0,1] + ) -> bool: + if probability == 0: + return False + if float(torch.rand(1)) < probability: + return True + else: + return False + + def get_unconditional_condition(self, + cond:Optional[Union[dict,Tensor]] = None, + cond_shape:Optional[tuple] = None, + condition_device:Optional[device] = None + ) -> Tensor: + print('Default Unconditional Condition. 
You might wanna overwrite this function') + if cond_shape is None: cond_shape = cond.shape + if cond is not None and isinstance(cond,Tensor): condition_device = cond.device + return (-11.4981 + torch.zeros(cond_shape)).to(condition_device) + + def get_x_shape(self, cond:Optional[Union[dict,Tensor]] = None): + return None + +if __name__ == '__main__': + DDPM(model = 'debug') + + + + diff --git a/TorchJaekwon/Model/Diffusion/DDPM/DDPMLearningVariances.py b/TorchJaekwon/Model/Diffusion/DDPM/DDPMLearningVariances.py new file mode 100644 index 0000000000000000000000000000000000000000..b9f9797f2e722a83cbd466c5badbae72bb0e3433 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/DDPM/DDPMLearningVariances.py @@ -0,0 +1,194 @@ +''' +2021_ICML_Improved denoising diffusion probabilistic models +Code Reference: https://github.com/facebookresearch/DiT +''' +#type +from typing import Optional, Union, Dict +from torch import Tensor +#package +import torch +import numpy as np +#torchjaekwon +from TorchJaekwon.Model.Diffusion.DDPM.DiffusionUtil import DiffusionUtil +from TorchJaekwon.Util.UtilTorch import UtilTorch +from TorchJaekwon.Util.UtilData import UtilData +from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM + +class DDPMLearningVariances(DDPM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @torch.no_grad() + def p_sample(self, + x:Tensor, + t:Tensor, + cond:Optional[Union[dict,Tensor]], + is_cond_unpack:bool, + clip_denoised:bool = True, + repeat_noise:bool = False): + b, *_, device = *x.shape, x.device + out = self.p_mean_variance( + x = x, + t = t, + cond = cond, + is_cond_unpack = is_cond_unpack, + clip_denoised=clip_denoised, + cfg_scale=self.cfg_scale + + ) + noise = DiffusionUtil.noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + return out["pred_mean"] + nonzero_mask * torch.exp(0.5 * out["pred_log_variance"]) * noise + + def p_losses(self, + x_start:Tensor, + cond:Optional[Union[dict,Tensor]], + is_cond_unpack:bool, + t:Tensor, + noise:Optional[Tensor] = None): + noise:Tensor = UtilData.default(noise, lambda: torch.randn_like(x_start)) + x_noisy:Tensor = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output:Tensor = self.apply_model(x_noisy, t, cond, is_cond_unpack) + + batch_size, channel_size = x_noisy.shape[:2] + assert model_output.shape == (batch_size, channel_size * 2, *x_noisy.shape[2:]), 'Model output size is expected to be (batch_size, channel_size * 2, ...), because it also predicts variance.' + model_output, model_var_values = torch.split(model_output, channel_size, dim=1) + # Learn the variance using the variational bound, but don't let it affect our mean prediction. + mean_frozen_output = torch.cat([model_output.detach(), model_var_values], dim=1) + + + vlb_loss = self.vb_terms_bpd(x_start=x_start, + x_t=x_noisy, + t=t, + cond=cond, + is_cond_unpack=is_cond_unpack, + model_output=mean_frozen_output, + clip_denoised=False, + )["output"] + + if self.model_output_type == 'x_start': + target:Tensor = x_start + elif self.model_output_type == 'noise': + target:Tensor = noise + elif self.model_output_type == 'v_prediction': + target:Tensor = self.get_v(x_start, noise, t) + else: + print(f'''model output type is {self.model_output_type}. 
It should be in [x_start, noise]''') + raise NotImplementedError() + if target.shape != model_output.shape: print(f'warning: target shape({target.shape}) and model shape({model_output.shape}) are different') + + return (self.loss_func(target, model_output) + vlb_loss).mean() + + def vb_terms_bpd(self, + x_start, + x_t, + t, + cond:Optional[Union[dict,Tensor]], + is_cond_unpack:bool, + model_output:Optional[Tensor] = None, + clip_denoised=True, + ): + """ + Get a term for the variational lower-bound. + bits per dimension (bpd). + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior( x_start=x_start, x_t=x_t, t=t ) + out = self.p_mean_variance( + x = x_t, + t = t, + cond = cond, + is_cond_unpack=is_cond_unpack, + model_output = model_output, + clip_denoised=clip_denoised, + ) + kl = UtilTorch.kl_div_gaussian( true_mean, true_log_variance_clipped, out["pred_mean"], out["pred_log_variance"]) + kl = UtilTorch.mean_flat(kl) / np.log(2.0) + + decoder_nll = -DiffusionUtil.discretized_gaussian_log_likelihood( + x_start, means=out["pred_mean"], log_scales=0.5 * out["pred_log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = UtilTorch.mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = torch.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_x_start": out["pred_x_start"]} + + def p_mean_variance(self, + x:Tensor, + t:Tensor, + cond:Optional[Union[dict,Tensor]], + is_cond_unpack:bool, + model_output:Optional[Tensor] = None, + denoised_fn:callable = None, + clip_denoised: bool = True, + cfg_scale = None + ) -> Dict[str,Tensor]: + B, C = x.shape[:2] + assert t.shape == (B,) + if model_output is None: model_output:Tensor = self.apply_model(x, t, cond, is_cond_unpack, cfg_scale) + + assert model_output.shape == (B, C * 2, *x.shape[2:]) + + #model learn variance + model_output, model_var_values = torch.split(model_output, C, dim=1) + min_log = DiffusionUtil.extract(self.posterior_log_variance_clipped, t, x.shape) + max_log = DiffusionUtil.extract(torch.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = torch.exp(model_log_variance) + + if self.model_output_type == "noise": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_output) + elif self.model_output_type == 'x_start': + x_recon = model_output + + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + x_recon.clamp_(-1., 1.) 
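# Interpretation of the interpolation above, per the Improved DDPM paper cited in this
# file's header: model_var_values == -1 everywhere yields
# model_log_variance == posterior_log_variance_clipped (the smallest variance the
# variational bound allows, i.e. beta_tilde_t), while model_var_values == +1 yields
# log(beta_t), the largest. Only the variance is learned here; the mean computed below
# is still the standard DDPM posterior mean obtained from the (optionally clamped)
# x_0 prediction via q_posterior.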
+ + model_mean, _, _ = self.q_posterior(x_start=x_recon, x_t=x, t=t) + assert model_mean.shape == model_log_variance.shape == x_recon.shape == x.shape + return { + "pred_mean": model_mean, + "pred_variance": model_variance, + "pred_log_variance": model_log_variance, + "pred_x_start": x_recon, + } + + def apply_model(self, + x:Tensor, + t:Tensor, + cond:Optional[Union[dict,Tensor]], + is_cond_unpack:bool, + cfg_scale:Optional[float] = None + ) -> Tensor: + if cfg_scale is None or cfg_scale == 1.0: + if cond is None: + return self.model(x, t) + elif is_cond_unpack: + return self.model(x, t, **cond) + else: + return self.model(x, t, cond) + else: + unconditional_conditioning = self.get_unconditional_condition(cond=cond) + cond_and_uncond = torch.cat([cond, unconditional_conditioning], dim=0) + x_for_cond_and_uncond = torch.cat([x, x], dim=0) + model_output = self.model(x_for_cond_and_uncond, t, **cond_and_uncond) if is_cond_unpack else self.model(x_for_cond_and_uncond, t, cond_and_uncond) + + eps, var = torch.split(model_output, model_output.shape[1] // 2, dim=1) + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + cond_var, _ = torch.split(var, len(var) // 2, dim=0) + + cfg_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + return torch.cat([cfg_eps, cond_var], dim=1) diff --git a/TorchJaekwon/Model/Diffusion/DDPM/DDPMLossVLB.py b/TorchJaekwon/Model/Diffusion/DDPM/DDPMLossVLB.py new file mode 100644 index 0000000000000000000000000000000000000000..05f245acc0bf56f78c551174272a612f7487c860 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/DDPM/DDPMLossVLB.py @@ -0,0 +1,170 @@ +from typing import Union, Callable, Literal, Optional, Tuple +from numpy import ndarray +from torch import Tensor, device + +from tqdm import tqdm +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + +from TorchJaekwon.GetModule import GetModule +from TorchJaekwon.Util.UtilData import UtilData +from TorchJaekwon.Util.UtilTorch import UtilTorch +from TorchJaekwon.Model.Diffusion.DDPM.DiffusionUtil import DiffusionUtil +from TorchJaekwon.Model.Diffusion.DDPM.BetaSchedule import BetaSchedule +from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM + +class DDPMLossVLB(DDPM): + def __init__(self, + use_vlb_loss:bool = True, + loss_simple_weight:float=1.0, + original_elbo_weight:float=0.0, + logvar_init:float=0.0, + learn_logvar:bool=False, + *args, + **kwargs): + self.use_vlb_loss = use_vlb_loss + super().__init__(*args, **kwargs) + + self.loss_simple_weight:float = loss_simple_weight + self.original_elbo_weight:float = original_elbo_weight + self.logvar:float = torch.full(fill_value=logvar_init, size=(self.timesteps,)) + self.learn_logvar:bool = learn_logvar + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + else: + self.logvar = nn.Parameter(self.logvar, requires_grad=False) + + def set_noise_schedule(self, + betas: Optional[ndarray] = None, + beta_schedule_type:Literal['linear','cosine'] = 'linear', + beta_arg_dict:dict = dict(), + timesteps:int = 1000, + ) -> None: + if betas is None: + beta_arg_dict.update({'timesteps':timesteps}) + betas = getattr(BetaSchedule,beta_schedule_type)(**beta_arg_dict) + + alphas:ndarray = 1. 
- betas + alphas_cumprod:ndarray = np.cumprod(alphas, axis=0) + alphas_cumprod_prev:ndarray = np.append(1., alphas_cumprod[:-1]) + + self.betas:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'betas', value = betas) + self.alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'alphas_cumprod', value = alphas_cumprod) + self.alphas_cumprod_prev:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'alphas_cumprod_prev', value = alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_alphas_cumprod', value = np.sqrt(alphas_cumprod)) + self.sqrt_one_minus_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_one_minus_alphas_cumprod', value = np.sqrt(1. - alphas_cumprod)) + self.log_one_minus_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'log_one_minus_alphas_cumprod', value = np.log(1. - alphas_cumprod)) + self.sqrt_recip_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_recip_alphas_cumprod', value = np.sqrt(1. / alphas_cumprod)) + self.sqrt_recipm1_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_recipm1_alphas_cumprod', value = np.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_variance', value = posterior_variance) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_log_variance_clipped', value = np.log(np.maximum(posterior_variance, 1e-20))) + self.posterior_mean_coef1:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_mean_coef1', value = betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + self.posterior_mean_coef2:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_mean_coef2', value = (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. 
- alphas_cumprod)) + + if self.use_vlb_loss: + if self.model_output_type == 'noise': + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * torch.tensor(alphas, dtype=torch.float32) + * (1 - self.alphas_cumprod) + ) + elif self.model_output_type == 'x_start': + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + elif self.model_output_type == 'v_prediction': + lvlb_weights = torch.ones_like( + self.betas**2 + / ( + 2 + * self.posterior_variance + * torch.tensor(alphas, dtype=torch.float32) + * (1 - self.alphas_cumprod) + ) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + self.lvlb_weights = self.lvlb_weights + assert not torch.isnan(self.lvlb_weights).all() + + def p_losses(self, + x_start:Tensor, + cond:Optional[Union[dict,Tensor]], + is_cond_unpack:bool, + t:Tensor, + noise:Optional[Tensor] = None): + if not self.use_vlb_loss: + return super().p_losses(x_start, cond, is_cond_unpack, t, noise) + + noise:Tensor = UtilData.default(noise, lambda: torch.randn_like(x_start)) + x_noisy:Tensor = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output:Tensor = self.apply_model(x_noisy, t, cond, is_cond_unpack) + + if self.model_output_type == 'x_start': + target:Tensor = x_start + elif self.model_output_type == 'noise': + target:Tensor = noise + elif self.model_output_type == 'v_prediction': + target:Tensor = self.get_v(x_start, noise, t) + else: + print(f'''model output type is {self.model_output_type}. It should be in [x_start, noise]''') + raise NotImplementedError() + if target.shape != model_output.shape: print(f'warning: target shape({target.shape}) and model shape({model_output.shape}) are different') + + loss_dict = dict() + loss_simple:Tensor = self.get_loss(model_output, target, mean=False) + loss_simple = loss_simple.mean(dim = list(range(len(loss_simple.shape)))[1:]) + loss_dict.update({f"loss_simple": loss_simple.mean()}) + + logvar_t = self.logvar[t] + loss = loss_simple / torch.exp(logvar_t) + logvar_t + + if self.learn_logvar: + loss_dict.update({f"loss_gamma": loss.mean()}) + loss_dict.update({"logvar": self.logvar.data.mean()}) + + loss = self.loss_simple_weight * loss.mean() + + loss_vlb:Tensor = self.get_loss(model_output, target, mean=False) + loss_vlb = loss_vlb.mean(dim=list(range(len(loss_vlb.shape)))[1:]) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f"loss_vlb": loss_vlb}) + loss += self.original_elbo_weight * loss_vlb + loss_dict.update({f"loss": loss}) + return loss_dict + + def get_loss(self, pred:Tensor, target:Tensor, mean=True) -> Tensor: + if self.loss_func == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_func == F.mse_loss: + if mean: + loss = self.loss_func(target, pred) + else: + loss = self.loss_func(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + +if __name__ == '__main__': + ddpm = DDPMLossVLB(model = lambda x, t: x, model_output_type = 'v_prediction') + ddpm.p_losses(x_start = torch.randn(2,3,64,64), cond = None, is_cond_unpack = False, t = torch.tensor([30, 23])) + print('finish') \ No newline at end of file diff --git a/TorchJaekwon/Model/Diffusion/DDPM/DiffusionUtil.py b/TorchJaekwon/Model/Diffusion/DDPM/DiffusionUtil.py new file mode 100644 index 
0000000000000000000000000000000000000000..83b29b751291ddec5333130df85a77740f3b0873 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/DDPM/DiffusionUtil.py @@ -0,0 +1,54 @@ +from torch import Tensor,device + +import torch +import numpy as np + +class DiffusionUtil: + @staticmethod + def extract(array:Tensor, t, x_shape): + batch_size, *_ = t.shape + out = array.gather(dim = -1, index = t).contiguous() + return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).contiguous() + + @staticmethod + def noise_like(shape:tuple, device:device, repeat:bool = False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + @staticmethod + def discretized_gaussian_log_likelihood(x, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = DiffusionUtil.approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = DiffusionUtil.approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x < -0.999, + log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + + @staticmethod + def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. 
+ """ + return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))) diff --git a/TorchJaekwon/Model/Diffusion/DDPM/Module/SinusoidalPosEmb.py b/TorchJaekwon/Model/Diffusion/DDPM/Module/SinusoidalPosEmb.py new file mode 100644 index 0000000000000000000000000000000000000000..461a522d14d6d9bfc158df5331afff4e2e5fc62c --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/DDPM/Module/SinusoidalPosEmb.py @@ -0,0 +1,18 @@ +import math + +import torch +import torch.nn as nn + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim:int = 256): + super().__init__() + self.dim:int = dim + + def forward(self, diffusion_step: torch.Tensor) -> torch.Tensor: + device = diffusion_step.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = diffusion_step[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb \ No newline at end of file diff --git a/TorchJaekwon/Model/Diffusion/Distillation/FlashDiffusion/FlashDiffusion.py b/TorchJaekwon/Model/Diffusion/Distillation/FlashDiffusion/FlashDiffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..02f86c759743f614277d48826ee6eca6fc5280fa --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/Distillation/FlashDiffusion/FlashDiffusion.py @@ -0,0 +1,656 @@ +import logging +from copy import deepcopy +from typing import Any, Dict, List, Tuple, Union, Literal + +try: import lpips +except: print('lpips is not installed') +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM +from tqdm import tqdm + +from TorchJaekwon.Model.Diffusion.External.diffusers.DiffusersWrapper import DiffusersWrapper +from TorchJaekwon.Model.Diffusion.External.diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler +from TorchJaekwon.Model.Diffusion.Distillation.FlashDiffusion.Utils import gaussian_mixture, append_dims, extract_into_tensor + + +class FlashDiffusion(nn.Module): + + def __init__( + self, + student_denoiser: DDPM, + teacher_denoiser: DDPM, + discriminator: torch.nn.Module = None, + teacher_noise_scheduler_class = DPMSolverMultistepScheduler, + distill_loss_type:Literal['l2', 'l1', 'lpips'] = 'l2', + gan_loss_type:Literal['lsgan', 'wgan', 'hinge', 'non-saturating'] = 'lsgan', + use_dmd_loss:bool = True, + distill_loss_scale:list = [1.0, 1.0, 1.0, 1.0], + adversarial_loss_scale:list = [0.0, 0.1, 0.2, 0.3], + dmd_loss_scale:list = [0.0, 0.3, 0.5, 0.7], + use_teacher_as_real:bool = False, + switch_teacher:bool = False, + K:list = [32, 32, 32, 32], + num_iterations_per_K:list=[5000, 5000, 5000, 5000], + timestep_distribution='mixture', + mode_probs=[[0.0, 0.0, 0.5, 0.5], + [0.1, 0.3, 0.3, 0.3], + [0.25, 0.25, 0.25, 0.25], + [0.4, 0.2, 0.2, 0.2]], + guidance_scale_min:float=3.0, + guidance_scale_max:float=13.0, + mixture_num_components:int = 4, + mixture_var:float = 0.5, + ): + super().__init__() + + self.student_denoiser = student_denoiser + self.teacher_denoiser = teacher_denoiser + teacher_noise_scheduler_args = self.get_diffuser_scheduler_config(teacher_denoiser) + self.teacher_noise_scheduler = teacher_noise_scheduler_class(**teacher_noise_scheduler_args) + + self.K = K + self.guidance_scale_min = [guidance_scale_min] * len(self.K) + self.guidance_scale_max = [guidance_scale_max] * len(self.K) + + self.num_iterations_per_K = num_iterations_per_K + self.distill_loss_type = distill_loss_type + 
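# Note on the list-valued hyperparameters stored in this __init__: training runs in
# phases, one per entry of K / num_iterations_per_K. In forward(), the phase index
# K_step is derived from the iteration counter via cumsum(num_iterations_per_K), and
# that same K_step indexes distill_loss_scale, adversarial_loss_scale, dmd_loss_scale,
# guidance_scale_min / guidance_scale_max, mode_probs, mixture_num_components and
# mixture_var. When timestep_distribution == 'mixture', _get_timesteps() turns
# mode_probs[K_step] and mixture_var[K_step] into a Gaussian-mixture categorical over
# the K teacher timesteps, from which each distillation step's starting timestep is drawn.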
self.timestep_distribution = timestep_distribution + self.iter_steps = 0 + self.mixture_num_components = [mixture_num_components] * len(self.K) + self.mixture_var = [mixture_var] * len(self.K) + self.use_dmd_loss = use_dmd_loss + self.dmd_loss_scale = dmd_loss_scale + self.distill_loss_scale = distill_loss_scale + self.discriminator = discriminator + self.adversarial_loss_scale = adversarial_loss_scale + self.gan_loss_type = gan_loss_type + self.mode_probs = mode_probs + self.use_teacher_as_real = use_teacher_as_real + self.switch_teacher = switch_teacher + self.disc_update_counter = 0 + + if self.discriminator is None: + logging.warning( + "No discriminator provided. Adversarial loss will be ignored." + ) + self.use_adversarial_loss = False + else: + self.use_adversarial_loss = True + + self.disc_backbone = self.teacher_denoiser + + if self.distill_loss_type == "lpips": + self.lpips = lpips.LPIPS(net="vgg") + + self.K_steps = np.cumsum(self.num_iterations_per_K) + self.K_prev = self.K[0] + + self.register_buffer( "sqrt_alpha_cumprod", torch.sqrt(self.teacher_denoiser.alphas_cumprod),) + self.register_buffer( "sigmas", torch.sqrt(1 - self.teacher_denoiser.alphas_cumprod),) + + def get_diffuser_scheduler_config(self, ddpm_module: DDPM): + output_type_dict = { + 'v_prediction': 'v_prediction', + 'noise': 'epsilon', + 'x_start': 'sample' + } + return { + 'num_train_timesteps': ddpm_module.timesteps, + 'trained_betas': ddpm_module.betas, + 'prediction_type': output_type_dict[ddpm_module.model_output_type], + 'timestep_spacing': "trailing" + } + + def _encode_inputs(self, batch: Dict[str, Any]): + """ + Encode the inputs using the VAE + """ + with torch.no_grad(): + vae_inputs = batch[self.vae.config.input_key] + return self.vae.encode(vae_inputs) + + def _get_timesteps( + self, num_samples: int = 1, K: int = 1, K_step: int = 1, device="cpu" + ): + # Get the timesteps for the current K + self.teacher_noise_scheduler.set_timesteps(K) + + if self.timestep_distribution == "uniform": + prob = torch.ones(K) / K + elif self.timestep_distribution == "gaussian": + prob = [torch.exp(-torch.tensor([(i - K / 2) ** 2 / K])) for i in range(K)] + prob = torch.tensor(prob) / torch.sum(torch.tensor(prob)) + elif self.timestep_distribution == "mixture": + mixture_num_components = self.mixture_num_components[K_step] + mode_probs = self.mode_probs[K_step] + + # Define targeted timesteps + locs = [ + i * (K // mixture_num_components) + for i in range(0, mixture_num_components) + ] + mixture_var = self.mixture_var[K_step] + prob = [ + gaussian_mixture( + K, + locs=locs, + var=mixture_var, + mode_probs=mode_probs, + )(i) + for i in range(K) + ] + prob = torch.tensor(prob) / torch.sum(torch.tensor(prob)) + + start_idx = torch.multinomial(prob, 1) + + # start_idx = torch.randint(0, len(self.teacher_noise_scheduler.timesteps), (1,)) + + start_timestep = ( + self.teacher_noise_scheduler.timesteps[start_idx] + .to(device) + .repeat(num_samples) + ) + + return start_idx, start_timestep + + def forward(self, + x_start, + cond, + is_cond_unpack:bool = False, + train_stage:Literal['generator', 'discriminator'] = 'generator' + ): + self.iter_steps += 1 + z, preprocessed_cond, additional_data_dict = self.teacher_denoiser.preprocess(x_start, cond) + + # Get conditioning + conditioning = preprocessed_cond + student_conditioning = preprocessed_cond + if DDPM.make_decision(self.student_denoiser.unconditional_prob): + student_conditioning = self.student_denoiser.get_unconditional_condition(cond=cond) + student_conditioning = 
self.teacher_denoiser.preprocess(None, student_conditioning)[1] + + # Get K for the current step + if self.iter_steps > self.K_steps[-1]: + K_step = len(self.K) - 1 + else: + K_step = np.argmax(self.iter_steps < self.K_steps) + K = self.K[K_step] + guidance_min = self.guidance_scale_min[K_step] + guidance_max = self.guidance_scale_max[K_step] + if K != self.K_prev: + self.K_prev = K + if self.switch_teacher: + print("Switching teacher") + self.teacher_denoiser = deepcopy(self.student_denoiser) + self.teacher_denoiser.freeze() + + # Create noisy samples + noise = torch.randn_like(z) + + # Sample the timesteps + start_idx, start_timestep = self._get_timesteps( + num_samples=z.shape[0], K=K, K_step=K_step, device=z.device + ) + + if start_idx == 0: + noisy_sample_init = noise + noisy_sample_init *= self.teacher_noise_scheduler.init_noise_sigma + noisy_sample_init_student = noise + + else: + # Add noise to sample + noisy_sample_init = self.teacher_noise_scheduler.add_noise( + z, noise, start_timestep + ) + noisy_sample_init_student = noisy_sample_init + + noisy_sample_init_ = self.teacher_noise_scheduler.scale_model_input( + noisy_sample_init_student, start_timestep + ) + + # Get student denoiser output + student_model_output:torch.Tensor = self.student_denoiser.apply_model(noisy_sample_init_, start_timestep, student_conditioning, is_cond_unpack ) + + c_skip, c_out = self._scalings_for_boundary_conditions(start_timestep) + + c_skip = append_dims(c_skip, noisy_sample_init_student.ndim) + c_out = append_dims(c_out, noisy_sample_init_student.ndim) + + student_output = self._predicted_x_0( + student_model_output, + start_timestep.type(torch.int64), + noisy_sample_init_student, + DiffusersWrapper.get_diffusers_output_type_name(self.student_denoiser), + self.sqrt_alpha_cumprod.to(z.device), + self.sigmas.to(z.device), + z, + ) + + noisy_sample = noisy_sample_init.clone().detach() + + guidance_scale = ( + torch.rand(1).to(z.device) * (guidance_max - guidance_min) + guidance_min + ) + + with torch.no_grad(): + for t in self.teacher_noise_scheduler.timesteps[start_idx:]: + timestep = torch.tensor([t], device=z.device).repeat(z.shape[0]) + + noisy_sample_ = self.teacher_noise_scheduler.scale_model_input( + noisy_sample, t + ) + teacher_model_output:torch.Tensor = self.teacher_denoiser.apply_model(noisy_sample_, timestep, conditioning, is_cond_unpack , guidance_scale) + + # Make one step on the reverse diffusion process + noisy_sample = self.teacher_noise_scheduler.step(teacher_model_output, #noise_pred, + t, + noisy_sample, + return_dict=False + )[0] + + teacher_output = noisy_sample + + student_output = c_skip * noisy_sample_init + c_out * student_output + + distill_loss = self._distill_loss(student_output, teacher_output) + + loss = ( + distill_loss + * self.distill_loss_scale[K_step] + ) + + if self.use_dmd_loss: + dmd_loss = self._dmd_loss( + student_output, + student_conditioning, + conditioning, + is_cond_unpack, + K, + K_step, + ) + loss += dmd_loss * self.dmd_loss_scale[K_step] + + if self.use_adversarial_loss: + gan_loss = self._gan_loss( + z, + student_output, + teacher_output, + conditioning, + is_cond_unpack, + train_stage=train_stage, + ) + print("GAN loss", gan_loss) + loss += self.adversarial_loss_scale[K_step] * gan_loss[0] + loss_disc = gan_loss[1] + + result_dict = { + "loss_dict": { + 'gen_total': loss, + }, + "teacher_output": teacher_output, + "student_output": student_output, + "noisy_sample": noisy_sample_init, + "start_timestep": start_timestep[0].item(), + } + + if 
self.use_adversarial_loss: + result_dict["loss_dict"]["disc_total"] = loss_disc + + if train_stage == "generator": + result_dict["loss_dict"]["gen_distill"] = distill_loss * self.distill_loss_scale[K_step] + if self.use_dmd_loss: + result_dict["loss_dict"]["gen_dmd"] = self.dmd_loss_scale[K_step] * dmd_loss + if self.use_adversarial_loss: + result_dict["loss_dict"]["gen_adv"] = self.adversarial_loss_scale[K_step] * gan_loss[0] + + return result_dict + + def _distill_loss(self, student_output, teacher_output): + if self.distill_loss_type == "l2": + return torch.mean( + ((student_output - teacher_output) ** 2).reshape( + student_output.shape[0], -1 + ), + 1, + ).mean() + elif self.distill_loss_type == "l1": + return torch.mean( + torch.abs(student_output - teacher_output).reshape( + student_output.shape[0], -1 + ), + 1, + ).mean() + elif self.distill_loss_type == "lpips": + # center crop patches of size 64x64 + crop_h = (student_output.shape[2] - 64) // 2 + crop_w = (student_output.shape[3] - 64) // 2 + student_output = student_output[ + :, :, crop_h : crop_h + 64, crop_w : crop_w + 64 + ] + teacher_output = teacher_output[ + :, :, crop_h : crop_h + 64, crop_w : crop_w + 64 + ] + + decoded_student = self.vae.decode(student_output).clamp(-1, 1) + decoded_teacher = self.vae.decode(teacher_output).clamp(-1, 1) + # self.lpips = self.lpips.to(student_output.device) + return self.lpips(decoded_student, decoded_teacher).mean() + else: + raise NotImplementedError(f"Loss type {self.loss_type} not implemented") + + def _dmd_loss( + self, + student_output, + student_conditioning, + conditioning, + #unconditional_conditioning, + is_cond_unpack, + K, + K_step, + ): + """ + Compute the DMD loss + """ + + # Sample noise + noise = torch.randn_like(student_output) + + timestep = torch.randint( + 0, + self.teacher_noise_scheduler.config.num_train_timesteps, + (student_output.shape[0],), + device=student_output.device, + ) + + # Create noisy sample + noisy_student = self.teacher_noise_scheduler.add_noise( + student_output, noise, timestep + ) + + with torch.no_grad(): + cond_fake_noise_pred = self.student_denoiser.apply_model(x=noisy_student, + t = timestep, + cond=student_conditioning, + is_cond_unpack = is_cond_unpack) + + if self.student_denoiser.model_output_type == "v_prediction": + cond_fake_noise_pred = self.student_denoiser.predict_noise_from_v(x_t=noisy_student, + t=timestep, + v=cond_fake_noise_pred) + + guidance_scale = ( + torch.rand(1).to(student_output.device) + * (self.guidance_scale_max[K_step] - self.guidance_scale_min[K_step]) + + self.guidance_scale_min[K_step] + ) + + real_noise_pred = self.teacher_denoiser.apply_model( + x=noisy_student, + t = timestep, + cond=conditioning, + is_cond_unpack = is_cond_unpack, + cfg_scale = guidance_scale + ) + if self.teacher_denoiser.model_output_type == "v_prediction": + real_noise_pred = self.teacher_denoiser.predict_noise_from_v(x_t=noisy_student, + t=timestep, + v=real_noise_pred) + + fake_noise_pred = cond_fake_noise_pred + + score_real = -real_noise_pred + score_fake = -fake_noise_pred + + alpha_prod_t = self.teacher_noise_scheduler.alphas_cumprod.to( + device=student_output.device, dtype=student_output.dtype + )[timestep] + beta_prod_t = 1.0 - alpha_prod_t + + coeff = ( + (score_fake - score_real) + * beta_prod_t.view(-1, 1, 1, 1) ** 0.5 + / alpha_prod_t.view(-1, 1, 1, 1) ** 0.5 + ) + + pred_x_0_student = self._predicted_x_0( + real_noise_pred, + timestep, + noisy_student, + "epsilon", + self.sqrt_alpha_cumprod, + self.sigmas, + student_output, + ) 
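# The remainder of _dmd_loss applies the distribution-matching-distillation gradient
# trick: `coeff` maps the difference between the student ("fake") and teacher ("real")
# scores into x_0 space via sqrt(1 - alpha_bar) / sqrt(alpha_bar), `weight` normalises
# it by the per-sample mean absolute reconstruction error, and the MSE against the
# detached, shifted target (student_output - weight * coeff) has a gradient w.r.t.
# student_output proportional to weight * coeff. Because both the weight and the target
# are detached, no gradient flows through either denoiser; the generator output is
# simply nudged along the teacher-versus-student score difference.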
+ + weight = ( + 1.0 + / ( + (student_output - pred_x_0_student).abs().mean([1, 2, 3], keepdim=True) + + 1e-5 + ).detach() + ) + return F.mse_loss( + student_output, (student_output - weight * coeff).detach(), reduction="mean" + ) + + def _gan_loss( + self, + z, + student_output, + teacher_output, + conditioning, + is_cond_unpack, + down_intrablock_additional_residuals=None, + train_stage:Literal['generator', 'discriminator'] = 'generator', + ): + + self.disc_update_counter += 1 + + # Sample noise + noise = torch.randn_like(student_output) + + if self.use_teacher_as_real: + real = teacher_output + + else: + real = z + + # Selected timesteps + selected_timesteps = [10, 250, 500, 750] + prob = torch.tensor([0.25, 0.25, 0.25, 0.25]) + + # Sample the timesteps + idx = prob.multinomial(student_output.shape[0], replacement=True).to( + student_output.device + ) + timesteps = torch.tensor( + selected_timesteps, device=student_output.device, dtype=torch.long + )[idx] + + # Create noisy sample + noisy_fake = self.teacher_noise_scheduler.add_noise( + student_output, noise, timesteps + ) + noisy_real = self.teacher_noise_scheduler.add_noise(real, noise, timesteps) + + # Concatenate noisy samples + noisy_sample = torch.cat([noisy_fake, noisy_real], dim=0) + + # Concatenate conditionings + if conditioning is not None: + conditioning = torch.cat([conditioning, conditioning], dim=0) + + # Concatenate timesteps + timestep = torch.cat([timesteps, timesteps], dim=0) + + # Predict noise level using denoiser + denoised_sample = self.disc_backbone.apply_model(noisy_sample, timestep, conditioning, is_cond_unpack) + + denoised_sample_fake, denoised_sample_real = denoised_sample.chunk(2, dim=0) + + if self.gan_loss_type == "wgan": + # Clip weights of discriminator + for p in self.discriminator.parameters(): + p.data.clamp_(-0.01, 0.01) + if step % 2 == 0: + loss_G = -self.discriminator(denoised_sample_fake).mean() + loss_D = 0 + else: + loss_D = ( + -self.discriminator(denoised_sample_real).mean() + + self.discriminator(denoised_sample_fake.detach()).mean() + ) + loss_G = 0 + + elif self.gan_loss_type == "lsgan": + valid = torch.ones(student_output.size(0), 1, device=noise.device) + fake = torch.zeros(noise.size(0), 1, device=noise.device) + if train_stage == "generator": + loss_G = F.mse_loss( + torch.sigmoid(self.discriminator(denoised_sample_fake)), valid + ) + loss_D = 0 + else: + loss_D = 0.5 * ( + F.mse_loss( + torch.sigmoid(self.discriminator(denoised_sample_real)), valid + ) + + F.mse_loss( + torch.sigmoid( + self.discriminator(denoised_sample_fake.detach()) + ), + fake, + ) + ) + loss_G = 0 + elif self.gan_loss_type == "hinge": + if train_stage == "generator": + loss_G = -self.discriminator(denoised_sample_fake).mean() + loss_D = 0 + else: + loss_D = ( + F.relu(1.0 - self.discriminator(denoised_sample_real)).mean() + + F.relu( + 1.0 + self.discriminator(denoised_sample_fake.detach()) + ).mean() + ) + loss_G = 0 + + elif self.gan_loss_type == "non-saturating": + if train_stage == "generator": + loss_G = -torch.mean( + torch.log( + torch.sigmoid(self.discriminator(denoised_sample_fake)) + 1e-8 + ) + ) + loss_D = 0 + + else: + loss_D = -torch.mean( + torch.log( + torch.sigmoid(self.discriminator(denoised_sample_real)) + 1e-8 + ) + + torch.log( + 1 + - torch.sigmoid( + self.discriminator(denoised_sample_fake.detach()) + ) + + 1e-8 + ) + ) + loss_G = 0 + else: + if train_stage == "generator": + valid = torch.ones(student_output.size(0), 1, device=noise.device) + loss_G = F.binary_cross_entropy_with_logits( + 
self.discriminator(denoised_sample_fake), valid + ) + loss_D = 0 + + else: + valid = torch.ones(student_output.size(0), 1, device=noise.device) + real = F.binary_cross_entropy_with_logits( + self.discriminator(denoised_sample_real), valid + ) + fake = torch.zeros(noise.size(0), 1, device=noise.device) + fake = F.binary_cross_entropy_with_logits( + self.discriminator(denoised_sample_fake.detach()), fake + ) + loss_D = real + fake + loss_G = 0 + + return [ + loss_G, + loss_D, + ] + + def _timestep_sampling( + self, n_samples: int = 1, device="cpu", timestep_sampling="uniform" + ) -> torch.Tensor: + if timestep_sampling == "uniform": + idx = self.prob.multinomial(n_samples, replacement=True).to(device) + + return torch.tensor( + self.selected_timesteps, device=device, dtype=torch.long + )[idx] + + elif timestep_sampling == "teacher": + return torch.randint( + 0, + self.teacher_noise_scheduler.config.num_train_timesteps, + (n_samples,), + device=device, + ) + + def _scalings_for_boundary_conditions(self, timestep, sigma_data=0.5): + """ + Compute the scalings for boundary conditions + """ + c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) + c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + def _predicted_x_0( + self, + model_output, + timesteps, + sample, + prediction_type, + alphas, + sigmas, + input_sample, + ): + """ + Predict x_0 using the model output and the timesteps depending on the prediction type + """ + if prediction_type == "epsilon": + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + alpha_mask = alphas > 0 + alpha_mask = alpha_mask.reshape(-1) + alpha_mask_0 = alphas == 0 + alpha_mask_0 = alpha_mask_0.reshape(-1) + pred_x_0 = torch.zeros_like(sample) + pred_x_0[alpha_mask] = ( + sample[alpha_mask] - sigmas[alpha_mask] * model_output[alpha_mask] + ) / alphas[alpha_mask] + pred_x_0[alpha_mask_0] = input_sample[alpha_mask_0] + elif prediction_type == "v_prediction": + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError( + f"Prediction type {prediction_type} currently not supported." 
+ ) + + return pred_x_0 + + def freeze(self): + """Freeze the model""" + self.eval() + for param in self.parameters(): + param.requires_grad = False \ No newline at end of file diff --git a/TorchJaekwon/Model/Diffusion/Distillation/FlashDiffusion/Utils.py b/TorchJaekwon/Model/Diffusion/Distillation/FlashDiffusion/Utils.py new file mode 100644 index 0000000000000000000000000000000000000000..342d9d651652bcf4362b961da186159f00e40a37 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/Distillation/FlashDiffusion/Utils.py @@ -0,0 +1,42 @@ +from typing import Tuple + +import torch + +def gaussian_mixture(k, locs, var, mode_probs=None): + if mode_probs is None: + mode_probs = [1 / len(locs)] * len(locs) + + def _gaussian(x): + prob = [ + mode_probs[i] * torch.exp(-torch.tensor([(x - loc) ** 2 / var])) + for i, loc in enumerate(locs) + ] + # prob.append(mode_prob * torch.exp(-torch.tensor([(x) ** 2 / var]))) + return sum(prob) + + return _gaussian + +def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + +def extract_into_tensor( + a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int, ...] +) -> torch.Tensor: + """ + Extracts values from a tensor into a new tensor using indices from another tensor. + + :param a: the tensor to extract values from. + :param t: the tensor containing the indices. + :param x_shape: the shape of the tensor to extract values into. + :return: a new tensor containing the extracted values. + """ + + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) \ No newline at end of file diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/DiffusersWrapper.py b/TorchJaekwon/Model/Diffusion/External/diffusers/DiffusersWrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..bd468b50c5def35a8eb62b740bd970dcb3b8f9df --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/DiffusersWrapper.py @@ -0,0 +1,61 @@ +from typing import Optional +from torch import Tensor, device + +import torch +from tqdm import tqdm +from TorchJaekwon.Util.UtilTorch import UtilTorch +from TorchJaekwon.Model.Diffusion.DDPM import DDPM + +class DiffusersWrapper: + @staticmethod + def get_diffusers_output_type_name(ddpm_module: DDPM) -> str: + output_type_dict = { + 'v_prediction': 'v_prediction', + 'noise': 'epsilon', + 'x_start': 'sample' + } + return output_type_dict[ddpm_module.model_output_type] + + @staticmethod + def get_diffusers_scheduler_config(ddpm_module: DDPM, scheduler_args: dict): + config:dict = { + 'num_train_timesteps': ddpm_module.timesteps, + 'trained_betas': ddpm_module.betas.to('cpu'), + 'prediction_type': DiffusersWrapper.get_diffusers_output_type_name(ddpm_module), + } + config.update(scheduler_args) + return config + + @staticmethod + def infer( + ddpm_module: DDPM, + diffusers_scheduler_class, + x_shape:tuple, + cond:Optional[dict] = None, + is_cond_unpack:bool = False, + num_steps: int = 20, + scheduler_args: dict = {'timestep_spacing': 'trailing'}, + cfg_scale: float = None, + device:device = None + ) -> Tensor: + + noise_scheduler = diffusers_scheduler_class(**DiffusersWrapper.get_diffusers_scheduler_config(ddpm_module, scheduler_args)) + _, cond, additional_data_dict = 
ddpm_module.preprocess(x_start = None, cond=cond) + if x_shape is None: x_shape = ddpm_module.get_x_shape(cond=cond) + noise_scheduler.set_timesteps(num_steps) + model_device:device = UtilTorch.get_model_device(ddpm_module) if device is None else device + x:Tensor = torch.randn(x_shape, device = model_device) + x = x * noise_scheduler.init_noise_sigma + for t in tqdm(noise_scheduler.timesteps, desc='sample time step'): + denoiser_input = noise_scheduler.scale_model_input(x, t) + model_output = ddpm_module.apply_model(denoiser_input, + torch.full((x_shape[0],), t, device=model_device, dtype=torch.long), + cond, + is_cond_unpack, + cfg_scale = ddpm_module.cfg_scale if cfg_scale is None else cfg_scale) + x = noise_scheduler.step( model_output, t, x, return_dict=False)[0] + + return ddpm_module.postprocess(x, additional_data_dict) + + + \ No newline at end of file diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/README.md b/TorchJaekwon/Model/Diffusion/External/diffusers/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e4ff7a4c3c8903ce5dc86343e9c435b9a5a43d73 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/README.md @@ -0,0 +1 @@ +Most of the code in the this folder is sourced from the [huggingface/diffusers](https://github.com/huggingface/diffusers). \ No newline at end of file diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/configuration_utils.py b/TorchJaekwon/Model/Diffusion/External/diffusers/configuration_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f51c0060f3afb1315375230c5e25a59fc16c5422 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/configuration_utils.py @@ -0,0 +1,728 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ConfigMixin base class and utilities.""" + +import dataclasses +import functools +import importlib +import inspect +import json +import os +import re +from collections import OrderedDict +from pathlib import PosixPath +from typing import Any, Dict, Tuple, Union + +import numpy as np +''' +from huggingface_hub import create_repo, hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, + validate_hf_hub_args, +) +from requests import HTTPError + +from . 
import __version__ +from .utils import ( + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + DummyObject, + deprecate, + extract_commit_hash, + http_user_agent, + logging, +) + + +logger = logging.get_logger(__name__) + +_re_configuration_file = re.compile(r"config\.(.*)\.json") + +''' + +from .utils import deprecate + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setitem__(name, value) + + +class ConfigMixin: + r""" + Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also + provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and + saving classes that inherit from [`ConfigMixin`]. + + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overridden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overridden by subclass). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). + - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function + should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by + subclass). + """ + + config_name = None + ignore_for_config = [] + has_compatibles = False + + _deprecated_kwargs = [] + + def register_to_config(self, **kwargs): + if self.config_name is None: + raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") + # Special case for `kwargs` used in deprecation warning added to schedulers + # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, + # or solve in a more general way. + kwargs.pop("kwargs", None) + + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + logger.debug(f"Updating config from {previous_dict} to {internal_dict}") + + self._internal_dict = FrozenDict(internal_dict) + + def __getattr__(self, name: str) -> Any: + """The only reason we overwrite `getattr` here is to gracefully deprecate accessing + config attributes directly. 
See https://github.com/huggingface/diffusers/pull/3129 + + This function is mostly copied from PyTorch's __getattr__ overwrite: + https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + """ + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) + return self._internal_dict[name] + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the + [`~ConfigMixin.from_config`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file is saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_config` + output_config_file = os.path.join(save_directory, self.config_name) + + self.to_json_file(output_config_file) + logger.info(f"Configuration saved in {output_config_file}") + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + + @classmethod + def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): + r""" + Instantiate a Python class from a config dictionary. + + Parameters: + config (`Dict[str, Any]`): + A config dictionary from which the Python class is instantiated. Make sure to only load configuration + files of compatible classes. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it is loaded) and initiate the Python class. + `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually + overwrite the same named arguments in `config`. 
+ + Returns: + [`ModelMixin`] or [`SchedulerMixin`]: + A model or scheduler object instantiated from a config dictionary. + + Examples: + + ```python + >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler + + >>> # Download scheduler from huggingface.co and cache. + >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32") + + >>> # Instantiate DDIM scheduler class with same config as DDPM + >>> scheduler = DDIMScheduler.from_config(scheduler.config) + + >>> # Instantiate PNDM scheduler class with same config as DDPM + >>> scheduler = PNDMScheduler.from_config(scheduler.config) + ``` + """ + # <===== TO BE REMOVED WITH DEPRECATION + # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated + if "pretrained_model_name_or_path" in kwargs: + config = kwargs.pop("pretrained_model_name_or_path") + + if config is None: + raise ValueError("Please make sure to provide a config as the first positional argument.") + # ======> + + if not isinstance(config, dict): + deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." + if "Scheduler" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead." + " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will" + " be removed in v1.0.0." + ) + elif "Model" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a model, please use {cls}.load_config(...) followed by" + f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary" + " instead. This functionality will be removed in v1.0.0." + ) + deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False) + config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs) + + init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) + + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + + # add possible deprecated kwargs + for deprecated_kwarg in cls._deprecated_kwargs: + if deprecated_kwarg in unused_kwargs: + init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) + + # Return model and optionally state and/or unused_kwargs + model = cls(**init_dict) + + # make sure to also save config parameters that might be used for compatible classes + # update _class_name + if "_class_name" in hidden_dict: + hidden_dict["_class_name"] = cls.__name__ + + model.register_to_config(**hidden_dict) + + # add hidden kwargs of compatible classes to unused_kwargs + unused_kwargs = {**unused_kwargs, **hidden_dict} + + if return_unused_kwargs: + return (model, unused_kwargs) + else: + return model + + @classmethod + def get_config_dict(cls, *args, **kwargs): + deprecation_message = ( + f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be" + " removed in version v1.0.0" + ) + deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) + return cls.load_config(*args, **kwargs) + + @classmethod + def load_config( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + return_unused_kwargs=False, + return_commit_hash=False, + **kwargs, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + r""" + Load a model or scheduler configuration. 
+ + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with + [`~ConfigMixin.save_config`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + return_unused_kwargs (`bool`, *optional*, defaults to `False): + Whether unused keyword arguments of the config are returned. + return_commit_hash (`bool`, *optional*, defaults to `False): + Whether the `commit_hash` of the loaded configuration are returned. + + Returns: + `dict`: + A dictionary of all the parameters stored in a JSON configuration file. + + """ + cache_dir = kwargs.pop("cache_dir", None) + local_dir = kwargs.pop("local_dir", None) + local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto") + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + _ = kwargs.pop("mirror", None) + subfolder = kwargs.pop("subfolder", None) + user_agent = kwargs.pop("user_agent", {}) + + user_agent = {**user_agent, "file_type": "config"} + user_agent = http_user_agent(user_agent) + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + if cls.config_name is None: + raise ValueError( + "`self.config_name` is not defined. Note that one should not load a config from " + "`ConfigMixin`. 
Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" + ) + + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): + # Load from a PyTorch checkpoint + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + config_file = hf_hub_download( + pretrained_model_name_or_path, + filename=cls.config_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + ) + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" + " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" + " token having permission to this repo with `token` or log in with `huggingface-cli login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" + " this model name. Check the model page at" + f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}." + ) + except HTTPError as err: + raise EnvironmentError( + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" + " run the library in offline mode at" + " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. 
" + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a {cls.config_name} file" + ) + + try: + # Load config dict + config_dict = cls._dict_from_json_file(config_file) + + commit_hash = extract_commit_hash(config_file) + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") + + if not (return_unused_kwargs or return_commit_hash): + return config_dict + + outputs = (config_dict,) + + if return_unused_kwargs: + outputs += (kwargs,) + + if return_commit_hash: + outputs += (commit_hash,) + + return outputs + + @staticmethod + def _get_init_keys(input_class): + return set(dict(inspect.signature(input_class.__init__).parameters).keys()) + + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): + # Skip keys that were not present in the original config, so default __init__ values were used + used_defaults = config_dict.get("_use_default_values", []) + config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"} + + # 0. Copy origin config dict + original_dict = dict(config_dict.items()) + + # 1. Retrieve expected config attributes from __init__ signature + expected_keys = cls._get_init_keys(cls) + expected_keys.remove("self") + # remove general kwargs if present in dict + if "kwargs" in expected_keys: + expected_keys.remove("kwargs") + # remove flax internal keys + if hasattr(cls, "_flax_internal_args"): + for arg in cls._flax_internal_args: + expected_keys.remove(arg) + + # 2. Remove attributes that cannot be expected from expected config attributes + # remove keys to be ignored + if len(cls.ignore_for_config) > 0: + expected_keys = expected_keys - set(cls.ignore_for_config) + + # load diffusers library to import compatible and original scheduler + diffusers_library = importlib.import_module(__name__.split(".")[0]) + + if cls.has_compatibles: + compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)] + else: + compatible_classes = [] + + expected_keys_comp_cls = set() + for c in compatible_classes: + expected_keys_c = cls._get_init_keys(c) + expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c) + expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls) + config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls} + + # remove attributes from orig class that cannot be expected + orig_cls_name = config_dict.pop("_class_name", cls.__name__) + if ( + isinstance(orig_cls_name, str) + and orig_cls_name != cls.__name__ + and hasattr(diffusers_library, orig_cls_name) + ): + orig_cls = getattr(diffusers_library, orig_cls_name) + unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys + config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig} + elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)): + raise ValueError( + "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)." + ) + + # remove private attributes + config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + + # 3. 
Create keyword arguments that will be passed to __init__ from expected keyword arguments + init_dict = {} + for key in expected_keys: + # if config param is passed to kwarg and is present in config dict + # it should overwrite existing config dict key + if key in kwargs and key in config_dict: + config_dict[key] = kwargs.pop(key) + + if key in kwargs: + # overwrite key + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + # use value from config dict + init_dict[key] = config_dict.pop(key) + + # 4. Give nice warning if unexpected values have been passed + if len(config_dict) > 0: + logger.warning( + f"The config attributes {config_dict} were passed to {cls.__name__}, " + "but are not expected and will be ignored. Please verify your " + f"{cls.config_name} configuration file." + ) + + # 5. Give nice info if config attributes are initialized to default because they have not been passed + passed_keys = set(init_dict.keys()) + if len(expected_keys - passed_keys) > 0: + logger.info( + f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." + ) + + # 6. Define unused keyword arguments + unused_kwargs = {**config_dict, **kwargs} + + # 7. Define "hidden" config parameters that were saved for compatible classes + hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} + + return init_dict, unused_kwargs, hidden_config_dict + + @classmethod + def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @property + def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ + return self._internal_dict + + def to_json_string(self) -> str: + """ + Serializes the configuration instance to a JSON string. + + Returns: + `str`: + String containing all the attributes that make up the configuration instance in JSON format. + """ + config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + config_dict["_class_name"] = self.__class__.__name__ + config_dict["_diffusers_version"] = __version__ + + def to_json_saveable(value): + if isinstance(value, np.ndarray): + value = value.tolist() + elif isinstance(value, PosixPath): + value = str(value) + return value + + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} + # Don't save "_ignore_files" or "_use_default_values" + config_dict.pop("_ignore_files", None) + config_dict.pop("_use_default_values", None) + + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save the configuration instance's parameters to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file to save a configuration instance's parameters. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +def register_to_config(init): + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. 
To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs)) + + new_kwargs = {**config_init_kwargs, **new_kwargs} + getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init + + +def flax_register_to_config(cls): + original_init = cls.__init__ + + @functools.wraps(original_init) + def init(self, *args, **kwargs): + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + # Ignore private kwargs in the init. Retrieve all passed attributes + init_kwargs = dict(kwargs.items()) + + # Retrieve default values + fields = dataclasses.fields(self) + default_kwargs = {} + for field in fields: + # ignore flax specific attributes + if field.name in self._flax_internal_args: + continue + if type(field.default) == dataclasses._MISSING_TYPE: + default_kwargs[field.name] = None + else: + default_kwargs[field.name] = getattr(self, field.name) + + # Make sure init_kwargs override default kwargs + new_kwargs = {**default_kwargs, **init_kwargs} + # dtype should be part of `init_kwargs`, but not `new_kwargs` + if "dtype" in new_kwargs: + new_kwargs.pop("dtype") + + # Get positional arguments aligned with kwargs + for i, arg in enumerate(args): + name = fields[i].name + new_kwargs[name] = arg + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs)) + + getattr(self, "register_to_config")(**new_kwargs) + original_init(self, *args, **kwargs) + + cls.__init__ = init + return cls + + +class LegacyConfigMixin(ConfigMixin): + r""" + A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more + pipeline-specific classes (like `DiTTransformer2DModel`). 
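+
+    For example (editor's illustrative note, not part of the original docstring): calling `from_config`
+    through the legacy class consults `_fetch_remapped_cls_from_config`, so a config that describes a
+    DiT-style transformer is instantiated as `DiTTransformer2DModel` even though the call was made on
+    the legacy `Transformer2DModel` class.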
+ """ + + @classmethod + def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): + # To prevent depedency import problem. + from .models.model_loading_utils import _fetch_remapped_cls_from_config + + # resolve remapping + remapped_class = _fetch_remapped_cls_from_config(config, cls) + + return remapped_class.from_config(config, return_unused_kwargs, **kwargs) diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/TorchJaekwon/Model/Diffusion/External/diffusers/schedulers/scheduling_dpmsolver_multistep.py new file mode 100644 index 0000000000000000000000000000000000000000..7831084fd8f0bbd52294c955e416929a306e3cdc --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -0,0 +1,1059 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils.deprecation_utils import deprecate +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. 
+ dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_lu_lambdas (`bool`, *optional*, defaults to `False`): + Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during + the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of + `lambda(t)`. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. 
Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + use_lu_lambdas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
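+            # Editor's note (illustrative comment, not part of the original diffusers source):
+            # "scaled_linear" spaces the betas linearly in sqrt(beta) space and then squares them,
+            # so beta_start and beta_end are preserved exactly while the intermediate betas stay
+            # smaller than a plain "linear" schedule would produce.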
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary timesteps schedule. 
If `None`, timesteps will be generated + based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` + must be `None`, and `timestep_spacing` attribute will be ignored. + """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") + if timesteps is not None and self.config.use_lu_lambdas: + raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`") + + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.int64) + else: + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + + if self.config.use_karras_sigmas: + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + elif self.config.use_lu_lambdas: + lambdas = np.flip(log_sigmas.copy()) + lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) + sigmas = np.exp(lambdas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Lu et al. 
(2022).""" + + lambda_min: float = in_lambdas[-1].item() + lambda_max: float = in_lambdas[0].item() + + rho = 1.0 # 1.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = lambda_min ** (1 / rho) + max_inv_rho = lambda_max ** (1 / rho) + lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return lambdas + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. 
+ if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: 
Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type 
== "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + 
schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
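+
+        Example (editor's illustrative sketch, not part of the original docstring; `unet` stands in
+        for any epsilon-prediction model and `latents` for the current noisy sample):
+
+            scheduler = DPMSolverMultistepScheduler(num_train_timesteps=1000)
+            scheduler.set_timesteps(num_inference_steps=20)
+            for t in scheduler.timesteps:
+                noise_pred = unet(latents, t)
+                latents = scheduler.step(noise_pred, t, latents).prev_sample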
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 + ) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to(device=model_output.device, dtype=torch.float32) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/schedulers/scheduling_utils.py b/TorchJaekwon/Model/Diffusion/External/diffusers/schedulers/scheduling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c30ecb36290bfcc3c068d7986272d0aa75be020b --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/schedulers/scheduling_utils.py @@ -0,0 +1,194 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import os +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + +import torch +#from huggingface_hub.utils import validate_hf_hub_args + +from ..utils import BaseOutput, PushToHubMixin + + +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + +# NOTE: We make this type an enum because it simplifies usage in docs and prevents +# circular imports when used for `_compatibles` within the schedulers module. +# When it's used as a type in pipelines, it really is a Union because the actual +# scheduler instance is passed in. 
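+# Editor's note (illustrative comment, not part of the original diffusers source): schedulers opt
+# into this family by listing these names in `_compatibles`; e.g. `DPMSolverMultistepScheduler`
+# earlier in this diff sets `_compatibles = [e.name for e in KarrasDiffusionSchedulers]`, which is
+# what allows `ConfigMixin.from_config` to swap in any of the schedulers enumerated below.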
+class KarrasDiffusionSchedulers(Enum): + DDIMScheduler = 1 + DDPMScheduler = 2 + PNDMScheduler = 3 + LMSDiscreteScheduler = 4 + EulerDiscreteScheduler = 5 + HeunDiscreteScheduler = 6 + EulerAncestralDiscreteScheduler = 7 + DPMSolverMultistepScheduler = 8 + DPMSolverSinglestepScheduler = 9 + KDPM2DiscreteScheduler = 10 + KDPM2AncestralDiscreteScheduler = 11 + DEISMultistepScheduler = 12 + UniPCMultistepScheduler = 13 + DPMSolverSDEScheduler = 14 + EDMEulerScheduler = 15 + + +AysSchedules = { + "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24], + "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0], + "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13], + "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0], + "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0], +} + + +@dataclass +class SchedulerOutput(BaseOutput): + """ + Base class for the output of a scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.Tensor + +class SchedulerMixin(PushToHubMixin): + """ + Base class for all schedulers. + + [`SchedulerMixin`] contains common functions shared by all schedulers such as general loading and saving + functionalities. + + [`ConfigMixin`] takes care of storing the configuration attributes (like `num_train_timesteps`) that are passed to + the scheduler's `__init__` function, and the attributes can be accessed by `scheduler.config.num_train_timesteps`. + + Class attributes: + - **_compatibles** (`List[str]`) -- A list of scheduler classes that are compatible with the parent scheduler + class. Use [`~ConfigMixin.from_config`] to load a different compatible scheduler class (should be overridden + by parent class). + """ + + config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + #@validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the scheduler + configuration saved with [`~SchedulerMixin.save_pretrained`]. + subfolder (`str`, *optional*): + The subfolder location of a model file within a larger model repository on the Hub or locally. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. 
+ resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. + + + + """ + config, kwargs, commit_hash = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + return_commit_hash=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to a directory so that it can be reloaded using the + [`~SchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
+ """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes \ No newline at end of file diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/utils/__init__.py b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a12bb3634fecf7ad7dbfa281c41c9ef598ce37a --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/__init__.py @@ -0,0 +1,3 @@ +from .outputs import BaseOutput +from .hub_utils import PushToHubMixin +from .deprecation_utils import deprecate \ No newline at end of file diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/utils/deprecation_utils.py b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/deprecation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f482deddd2f46b8d2e29d5229faa0e9a21f2fd98 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/deprecation_utils.py @@ -0,0 +1,49 @@ +import inspect +import warnings +from typing import Any, Dict, Optional, Union + +from packaging import version + + +def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2): + from .. import __version__ + + deprecated_kwargs = take_from + values = () + if not isinstance(args[0], tuple): + args = (args,) + + for attribute, version_name, message in args: + if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): + raise ValueError( + f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" + f" version {__version__} is >= {version_name}" + ) + + warning = None + if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: + values += (deprecated_kwargs.pop(attribute),) + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." + elif hasattr(deprecated_kwargs, attribute): + values += (getattr(deprecated_kwargs, attribute),) + warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." + elif deprecated_kwargs is None: + warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." 
+ + if warning is not None: + warning = warning + " " if standard_warn else "" + warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel) + + if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: + call_frame = inspect.getouterframes(inspect.currentframe())[1] + filename = call_frame.filename + line_number = call_frame.lineno + function = call_frame.function + key, value = next(iter(deprecated_kwargs.items())) + raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") + + if len(values) == 0: + return + elif len(values) == 1: + return values[0] + return values diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/utils/hub_utils.py b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/hub_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f437920598c0fb14200fe88b612b2b63b39990d --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/hub_utils.py @@ -0,0 +1,475 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import sys +import tempfile +from pathlib import Path +from typing import Dict, List, Optional, Union +from uuid import uuid4 + +from .import_utils import ( + ENV_VARS_TRUE_VALUES, + _flax_version, + _jax_version, + _onnxruntime_version, + _torch_version, + is_flax_available, + is_onnx_available, + is_torch_available, +) +from .logging import get_logger + + +logger = get_logger(__name__) + +MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md" +SESSION_ID = uuid4().hex + + +def populate_model_card(model_card, + tags: Union[str, List[str]] = None):# -> ModelCard: + """Populates the `model_card` with library name and optional tags.""" + if model_card.data.library_name is None: + model_card.data.library_name = "diffusers" + + if tags is not None: + if isinstance(tags, str): + tags = [tags] + if model_card.data.tags is None: + model_card.data.tags = [] + for tag in tags: + model_card.data.tags.append(tag) + + return model_card + + +def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None): + """ + Extracts the commit hash from a resolved filename toward a cache file. + """ + if resolved_file is None or commit_hash is not None: + return commit_hash + resolved_file = str(Path(resolved_file).as_posix()) + search = re.search(r"snapshots/([^/]+)/", resolved_file) + if search is None: + return None + commit_hash = search.groups()[0] + return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None + + +# Old default cache path, potentially to be migrated. +# This logic was more or less taken from `transformers`, with the following differences: +# - Diffusers doesn't use custom environment variables to specify the cache path. +# - There is no need to migrate the cache format, just move the files to the new location. 
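+# Illustrative note (not part of the upstream file): with neither HF_HOME nor XDG_CACHE_HOME
+# set, the lookup below falls back to the default Hugging Face cache directory, and the legacy
+# diffusers cache is a subdirectory of it, e.g.:
+#
+#     hf_cache_home       == os.path.expanduser("~/.cache/huggingface")
+#     old_diffusers_cache == os.path.expanduser("~/.cache/huggingface/diffusers")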
+hf_cache_home = os.path.expanduser( + os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) +) +old_diffusers_cache = os.path.join(hf_cache_home, "diffusers") + + +def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None: + if new_cache_dir is None: + new_cache_dir = HF_HUB_CACHE + if old_cache_dir is None: + old_cache_dir = old_diffusers_cache + + old_cache_dir = Path(old_cache_dir).expanduser() + new_cache_dir = Path(new_cache_dir).expanduser() + for old_blob_path in old_cache_dir.glob("**/blobs/*"): + if old_blob_path.is_file() and not old_blob_path.is_symlink(): + new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir) + new_blob_path.parent.mkdir(parents=True, exist_ok=True) + os.replace(old_blob_path, new_blob_path) + try: + os.symlink(new_blob_path, old_blob_path) + except OSError: + logger.warning( + "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded." + ) + # At this point, old_cache_dir contains symlinks to the new cache (it can still be used). + +''' +cache_version_file = os.path.join(HF_HUB_CACHE, "version_diffusers_cache.txt") +if not os.path.isfile(cache_version_file): + cache_version = 0 +else: + with open(cache_version_file) as f: + try: + cache_version = int(f.read()) + except ValueError: + cache_version = 0 + +if cache_version < 1: + old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0 + if old_cache_is_not_empty: + logger.warning( + "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your " + "existing cached models. This is a one-time operation, you can interrupt it or run it " + "later by calling `diffusers.utils.hub_utils.move_cache()`." + ) + try: + move_cache() + except Exception as e: + trace = "\n".join(traceback.format_tb(e.__traceback__)) + logger.error( + f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease " + "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole " + "message and we will do our best to help." + ) + +if cache_version < 1: + try: + os.makedirs(HF_HUB_CACHE, exist_ok=True) + with open(cache_version_file, "w") as f: + f.write("1") + except Exception: + logger.warning( + f"There was a problem when trying to write in your cache folder ({HF_HUB_CACHE}). Please, ensure " + "the directory exists and can be written to." 
+ ) + + +def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name + + +@validate_hf_hub_args +def _get_model_file( + pretrained_model_name_or_path: Union[str, Path], + *, + weights_name: str, + subfolder: Optional[str] = None, + cache_dir: Optional[str] = None, + force_download: bool = False, + proxies: Optional[Dict] = None, + resume_download: Optional[bool] = None, + local_files_only: bool = False, + token: Optional[str] = None, + user_agent: Optional[Union[Dict, str]] = None, + revision: Optional[str] = None, + commit_hash: Optional[str] = None, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + return pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + # Load from a PyTorch checkpoint + model_file = os.path.join(pretrained_model_name_or_path, weights_name) + return model_file + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + ): + model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + return model_file + else: + raise EnvironmentError( + f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." + ) + else: + # 1. First check if deprecated way of loading from branches is used + if ( + revision in DEPRECATED_REVISION_ARGS + and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME) + and version.parse(version.parse(__version__).base_version) >= version.parse("0.22.0") + ): + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=_add_variant(weights_name, revision), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision or commit_hash, + ) + warnings.warn( + f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", + FutureWarning, + ) + return model_file + except: # noqa: E722 + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.", + FutureWarning, + ) + try: + # 2. 
Load model file as usual + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=weights_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision or commit_hash, + ) + return model_file + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `token` or log in with `huggingface-cli " + "login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." + ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {weights_name} or" + " \nCheckout your internet connection or see how to run the library in" + " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {weights_name}" + ) + + +# Adapted from +# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976 +# Differences are in parallelization of shard downloads and checking if shards are present. + + +def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames): + shards_path = os.path.join(local_dir, subfolder) + shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames] + for shard_file in shard_filenames: + if not os.path.exists(shard_file): + raise ValueError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + + +def _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_filename, + cache_dir=None, + proxies=None, + resume_download=False, + local_files_only=False, + token=None, + user_agent=None, + revision=None, + subfolder="", +): + """ + For a given model: + + - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the + Hub + - returns the list of paths to all the shards, as well as some metadata. + + For the description of each arg, see [`PreTrainedModel.from_pretrained`]. 
`index_filename` is the full path to the + index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). + """ + if not os.path.isfile(index_filename): + raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + + original_shard_filenames = sorted(set(index["weight_map"].values())) + sharded_metadata = index["metadata"] + sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys()) + sharded_metadata["weight_map"] = index["weight_map"].copy() + shards_path = os.path.join(pretrained_model_name_or_path, subfolder) + + # First, let's deal with local folder. + if os.path.isdir(pretrained_model_name_or_path): + _check_if_shards_exist_locally( + pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames + ) + return pretrained_model_name_or_path, sharded_metadata + + # At this stage pretrained_model_name_or_path is a model identifier on the Hub + allow_patterns = original_shard_filenames + ignore_patterns = ["*.json", "*.md"] + if not local_files_only: + # `model_info` call must guarded with the above condition. + model_files_info = model_info(pretrained_model_name_or_path) + for shard_file in original_shard_filenames: + shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) + if not shard_file_present: + raise EnvironmentError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + + try: + # Load from URL + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + + # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so + # we don't have to catch them here. We have also dealt with EntryNotFoundError. + except HTTPError as e: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" + " again after checking your internet connection." + ) from e + + # If `local_files_only=True`, `cached_folder` may not contain all the shard files. + if local_files_only: + _check_if_shards_exist_locally( + local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames + ) + + return cached_folder, sharded_metadata + +''' +class PushToHubMixin: + """ + A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub. + """ + + def _upload_folder( + self, + working_dir: Union[str, os.PathLike], + repo_id: str, + token: Optional[str] = None, + commit_message: Optional[str] = None, + create_pr: bool = False, + ): + """ + Uploads all files in `working_dir` to `repo_id`. 
+ """ + if commit_message is None: + if "Model" in self.__class__.__name__: + commit_message = "Upload model" + elif "Scheduler" in self.__class__.__name__: + commit_message = "Upload scheduler" + else: + commit_message = f"Upload {self.__class__.__name__}" + + logger.info(f"Uploading the files of {working_dir} to {repo_id}.") + return upload_folder( + repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr + ) + + def push_to_hub( + self, + repo_id: str, + commit_message: Optional[str] = None, + private: Optional[bool] = None, + token: Optional[str] = None, + create_pr: bool = False, + safe_serialization: bool = True, + variant: Optional[str] = None, + ) -> str: + """ + Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub. + + Parameters: + repo_id (`str`): + The name of the repository you want to push your model, scheduler, or pipeline files to. It should + contain your organization name when pushing to an organization. `repo_id` can also be a path to a local + directory. + commit_message (`str`, *optional*): + Message to commit while pushing. Default to `"Upload {object}"`. + private (`bool`, *optional*): + Whether or not the repository created should be private. + token (`str`, *optional*): + The token to use as HTTP bearer authorization for remote files. The token generated when running + `huggingface-cli login` (stored in `~/.huggingface`). + create_pr (`bool`, *optional*, defaults to `False`): + Whether or not to create a PR with the uploaded files or directly commit. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether or not to convert the model weights to the `safetensors` format. + variant (`str`, *optional*): + If specified, weights are saved in the format `pytorch_model..bin`. + + Examples: + + ```python + from diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="unet") + + # Push the `unet` to your namespace with the name "my-finetuned-unet". + unet.push_to_hub("my-finetuned-unet") + + # Push the `unet` to an organization with the name "my-finetuned-unet". + unet.push_to_hub("your-org/my-finetuned-unet") + ``` + """ + repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id + + # Create a new empty model card and eventually tag it + model_card = load_or_create_model_card(repo_id, token=token) + model_card = populate_model_card(model_card) + + # Save all files. + save_kwargs = {"safe_serialization": safe_serialization} + if "Scheduler" not in self.__class__.__name__: + save_kwargs.update({"variant": variant}) + + with tempfile.TemporaryDirectory() as tmpdir: + self.save_pretrained(tmpdir, **save_kwargs) + + # Update model card if needed: + model_card.save(os.path.join(tmpdir, "README.md")) + + return self._upload_folder( + tmpdir, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/utils/import_utils.py b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..961a8a4f6f0803a2947476ab1e081b448f0964b4 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/import_utils.py @@ -0,0 +1,816 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" + +import importlib.util +import operator as op +import os +import sys +from collections import OrderedDict +from itertools import chain +from types import ModuleType +from typing import Any, Union + +#from huggingface_hub.utils import is_jinja_available # noqa: F401 +from packaging import version +from packaging.version import Version, parse + +from . import logging + + +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() +USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() +USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper() +DIFFUSERS_SLOW_IMPORT = os.environ.get("DIFFUSERS_SLOW_IMPORT", "FALSE").upper() +DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES + +STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + +_torch_version = "N/A" +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available = importlib.util.find_spec("torch") is not None + if _torch_available: + try: + _torch_version = importlib_metadata.version("torch") + logger.info(f"PyTorch version {_torch_version} available.") + except importlib_metadata.PackageNotFoundError: + _torch_available = False +else: + logger.info("Disabling PyTorch because USE_TORCH is set") + _torch_available = False + +_torch_xla_available = importlib.util.find_spec("torch_xla") is not None +if _torch_xla_available: + try: + _torch_xla_version = importlib_metadata.version("torch_xla") + logger.info(f"PyTorch XLA version {_torch_xla_version} available.") + except ImportError: + _torch_xla_available = False + +# check whether torch_npu is available +_torch_npu_available = importlib.util.find_spec("torch_npu") is not None +if _torch_npu_available: + try: + _torch_npu_version = importlib_metadata.version("torch_npu") + logger.info(f"torch_npu version {_torch_npu_version} available.") + except ImportError: + _torch_npu_available = False + +_jax_version = "N/A" +_flax_version = "N/A" +if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None + if _flax_available: + try: + _jax_version = importlib_metadata.version("jax") + _flax_version = importlib_metadata.version("flax") + logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") + except importlib_metadata.PackageNotFoundError: + _flax_available = False +else: + _flax_available = False + +if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES: + _safetensors_available = importlib.util.find_spec("safetensors") is not 
None + if _safetensors_available: + try: + _safetensors_version = importlib_metadata.version("safetensors") + logger.info(f"Safetensors version {_safetensors_version} available.") + except importlib_metadata.PackageNotFoundError: + _safetensors_available = False +else: + logger.info("Disabling Safetensors because USE_TF is set") + _safetensors_available = False + +_transformers_available = importlib.util.find_spec("transformers") is not None +try: + _transformers_version = importlib_metadata.version("transformers") + logger.debug(f"Successfully imported transformers version {_transformers_version}") +except importlib_metadata.PackageNotFoundError: + _transformers_available = False + + +_inflect_available = importlib.util.find_spec("inflect") is not None +try: + _inflect_version = importlib_metadata.version("inflect") + logger.debug(f"Successfully imported inflect version {_inflect_version}") +except importlib_metadata.PackageNotFoundError: + _inflect_available = False + + +_unidecode_available = importlib.util.find_spec("unidecode") is not None +try: + _unidecode_version = importlib_metadata.version("unidecode") + logger.debug(f"Successfully imported unidecode version {_unidecode_version}") +except importlib_metadata.PackageNotFoundError: + _unidecode_available = False + +_onnxruntime_version = "N/A" +_onnx_available = importlib.util.find_spec("onnxruntime") is not None +if _onnx_available: + candidates = ( + "onnxruntime", + "onnxruntime-gpu", + "ort_nightly_gpu", + "onnxruntime-directml", + "onnxruntime-openvino", + "ort_nightly_directml", + "onnxruntime-rocm", + "onnxruntime-training", + ) + _onnxruntime_version = None + # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu + for pkg in candidates: + try: + _onnxruntime_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _onnx_available = _onnxruntime_version is not None + if _onnx_available: + logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") + +# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed. 
+# _opencv_available = importlib.util.find_spec("opencv-python") is not None +try: + candidates = ( + "opencv-python", + "opencv-contrib-python", + "opencv-python-headless", + "opencv-contrib-python-headless", + ) + _opencv_version = None + for pkg in candidates: + try: + _opencv_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _opencv_available = _opencv_version is not None + if _opencv_available: + logger.debug(f"Successfully imported cv2 version {_opencv_version}") +except importlib_metadata.PackageNotFoundError: + _opencv_available = False + +_scipy_available = importlib.util.find_spec("scipy") is not None +try: + _scipy_version = importlib_metadata.version("scipy") + logger.debug(f"Successfully imported scipy version {_scipy_version}") +except importlib_metadata.PackageNotFoundError: + _scipy_available = False + +_librosa_available = importlib.util.find_spec("librosa") is not None +try: + _librosa_version = importlib_metadata.version("librosa") + logger.debug(f"Successfully imported librosa version {_librosa_version}") +except importlib_metadata.PackageNotFoundError: + _librosa_available = False + +_accelerate_available = importlib.util.find_spec("accelerate") is not None +try: + _accelerate_version = importlib_metadata.version("accelerate") + logger.debug(f"Successfully imported accelerate version {_accelerate_version}") +except importlib_metadata.PackageNotFoundError: + _accelerate_available = False + +_xformers_available = importlib.util.find_spec("xformers") is not None +try: + _xformers_version = importlib_metadata.version("xformers") + if _torch_available: + _torch_version = importlib_metadata.version("torch") + if version.Version(_torch_version) < version.Version("1.12"): + raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12") + + logger.debug(f"Successfully imported xformers version {_xformers_version}") +except importlib_metadata.PackageNotFoundError: + _xformers_available = False + +_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None +try: + _k_diffusion_version = importlib_metadata.version("k_diffusion") + logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}") +except importlib_metadata.PackageNotFoundError: + _k_diffusion_available = False + +_note_seq_available = importlib.util.find_spec("note_seq") is not None +try: + _note_seq_version = importlib_metadata.version("note_seq") + logger.debug(f"Successfully imported note-seq version {_note_seq_version}") +except importlib_metadata.PackageNotFoundError: + _note_seq_available = False + +_wandb_available = importlib.util.find_spec("wandb") is not None +try: + _wandb_version = importlib_metadata.version("wandb") + logger.debug(f"Successfully imported wandb version {_wandb_version }") +except importlib_metadata.PackageNotFoundError: + _wandb_available = False + + +_tensorboard_available = importlib.util.find_spec("tensorboard") +try: + _tensorboard_version = importlib_metadata.version("tensorboard") + logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}") +except importlib_metadata.PackageNotFoundError: + _tensorboard_available = False + + +_compel_available = importlib.util.find_spec("compel") +try: + _compel_version = importlib_metadata.version("compel") + logger.debug(f"Successfully imported compel version {_compel_version}") +except importlib_metadata.PackageNotFoundError: + _compel_available = False + + +_ftfy_available = 
importlib.util.find_spec("ftfy") is not None +try: + _ftfy_version = importlib_metadata.version("ftfy") + logger.debug(f"Successfully imported ftfy version {_ftfy_version}") +except importlib_metadata.PackageNotFoundError: + _ftfy_available = False + + +_bs4_available = importlib.util.find_spec("bs4") is not None +try: + # importlib metadata under different name + _bs4_version = importlib_metadata.version("beautifulsoup4") + logger.debug(f"Successfully imported ftfy version {_bs4_version}") +except importlib_metadata.PackageNotFoundError: + _bs4_available = False + +_torchsde_available = importlib.util.find_spec("torchsde") is not None +try: + _torchsde_version = importlib_metadata.version("torchsde") + logger.debug(f"Successfully imported torchsde version {_torchsde_version}") +except importlib_metadata.PackageNotFoundError: + _torchsde_available = False + +_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None +try: + _invisible_watermark_version = importlib_metadata.version("invisible-watermark") + logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}") +except importlib_metadata.PackageNotFoundError: + _invisible_watermark_available = False + + +_peft_available = importlib.util.find_spec("peft") is not None +try: + _peft_version = importlib_metadata.version("peft") + logger.debug(f"Successfully imported peft version {_peft_version}") +except importlib_metadata.PackageNotFoundError: + _peft_available = False + +_torchvision_available = importlib.util.find_spec("torchvision") is not None +try: + _torchvision_version = importlib_metadata.version("torchvision") + logger.debug(f"Successfully imported torchvision version {_torchvision_version}") +except importlib_metadata.PackageNotFoundError: + _torchvision_available = False + +_matplotlib_available = importlib.util.find_spec("matplotlib") is not None +try: + _matplotlib_version = importlib_metadata.version("matplotlib") + logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}") +except importlib_metadata.PackageNotFoundError: + _matplotlib_available = False + +_timm_available = importlib.util.find_spec("timm") is not None +if _timm_available: + try: + _timm_version = importlib_metadata.version("timm") + logger.info(f"Timm version {_timm_version} available.") + except importlib_metadata.PackageNotFoundError: + _timm_available = False + + +def is_timm_available(): + return _timm_available + + +_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None +try: + _bitsandbytes_version = importlib_metadata.version("bitsandbytes") + logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}") +except importlib_metadata.PackageNotFoundError: + _bitsandbytes_available = False + +# Taken from `huggingface_hub`. +_is_notebook = False +try: + shell_class = get_ipython().__class__ # type: ignore # noqa: F821 + for parent_class in shell_class.__mro__: # e.g. 
"is subclass of" + if parent_class.__name__ == "ZMQInteractiveShell": + _is_notebook = True # Jupyter notebook, Google colab or qtconsole + break +except NameError: + pass # Probably standard Python interpreter + +_is_google_colab = "google.colab" in sys.modules + + +def is_torch_available(): + return _torch_available + + +def is_torch_xla_available(): + return _torch_xla_available + + +def is_torch_npu_available(): + return _torch_npu_available + + +def is_flax_available(): + return _flax_available + + +def is_transformers_available(): + return _transformers_available + + +def is_inflect_available(): + return _inflect_available + + +def is_unidecode_available(): + return _unidecode_available + + +def is_onnx_available(): + return _onnx_available + + +def is_opencv_available(): + return _opencv_available + + +def is_scipy_available(): + return _scipy_available + + +def is_librosa_available(): + return _librosa_available + + +def is_xformers_available(): + return _xformers_available + + +def is_accelerate_available(): + return _accelerate_available + + +def is_k_diffusion_available(): + return _k_diffusion_available + + +def is_note_seq_available(): + return _note_seq_available + + +def is_wandb_available(): + return _wandb_available + + +def is_tensorboard_available(): + return _tensorboard_available + + +def is_compel_available(): + return _compel_available + + +def is_ftfy_available(): + return _ftfy_available + + +def is_bs4_available(): + return _bs4_available + + +def is_torchsde_available(): + return _torchsde_available + + +def is_invisible_watermark_available(): + return _invisible_watermark_available + + +def is_peft_available(): + return _peft_available + + +def is_torchvision_available(): + return _torchvision_available + + +def is_matplotlib_available(): + return _matplotlib_available + + +def is_safetensors_available(): + return _safetensors_available + + +def is_bitsandbytes_available(): + return _bitsandbytes_available + + +def is_notebook(): + return _is_notebook + + +def is_google_colab(): + return _is_google_colab + + +# docstyle-ignore +FLAX_IMPORT_ERROR = """ +{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the +installation page: https://github.com/google/flax and follow the ones that match your environment. +""" + +# docstyle-ignore +INFLECT_IMPORT_ERROR = """ +{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install +inflect` +""" + +# docstyle-ignore +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +""" + +# docstyle-ignore +ONNX_IMPORT_ERROR = """ +{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip +install onnxruntime` +""" + +# docstyle-ignore +OPENCV_IMPORT_ERROR = """ +{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip +install opencv-python` +""" + +# docstyle-ignore +SCIPY_IMPORT_ERROR = """ +{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install +scipy` +""" + +# docstyle-ignore +LIBROSA_IMPORT_ERROR = """ +{0} requires the librosa library but it was not found in your environment. 
Checkout the instructions on the +installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment. +""" + +# docstyle-ignore +TRANSFORMERS_IMPORT_ERROR = """ +{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip +install transformers` +""" + +# docstyle-ignore +UNIDECODE_IMPORT_ERROR = """ +{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install +Unidecode` +""" + +# docstyle-ignore +K_DIFFUSION_IMPORT_ERROR = """ +{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip +install k-diffusion` +""" + +# docstyle-ignore +NOTE_SEQ_IMPORT_ERROR = """ +{0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip +install note-seq` +""" + +# docstyle-ignore +WANDB_IMPORT_ERROR = """ +{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip +install wandb` +""" + +# docstyle-ignore +TENSORBOARD_IMPORT_ERROR = """ +{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip +install tensorboard` +""" + + +# docstyle-ignore +COMPEL_IMPORT_ERROR = """ +{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel` +""" + +# docstyle-ignore +BS4_IMPORT_ERROR = """ +{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip: +`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +FTFY_IMPORT_ERROR = """ +{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the +installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones +that match your environment. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +TORCHSDE_IMPORT_ERROR = """ +{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde` +""" + +# docstyle-ignore +INVISIBLE_WATERMARK_IMPORT_ERROR = """ +{0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0` +""" + +# docstyle-ignore +PEFT_IMPORT_ERROR = """ +{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft` +""" + +# docstyle-ignore +SAFETENSORS_IMPORT_ERROR = """ +{0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors` +""" + +# docstyle-ignore +BITSANDBYTES_IMPORT_ERROR = """ +{0} requires the bitsandbytes library but it was not found in your environment. 
You can install it with pip: `pip install bitsandbytes` +""" + +BACKENDS_MAPPING = OrderedDict( + [ + ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), + ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), + ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), + ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), + ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), + ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), + ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), + ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), + ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), + ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), + ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)), + ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), + ] +) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) + + if name in [ + "VersatileDiffusionTextToImagePipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionDualGuidedPipeline", + "StableDiffusionImageVariationPipeline", + "UnCLIPPipeline", + ] and is_transformers_version("<", "4.25.0"): + raise ImportError( + f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install" + " --upgrade transformers \n```" + ) + + if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version( + "<", "4.26.0" + ): + raise ImportError( + f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install" + " --upgrade transformers \n```" + ) + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattr__(cls, key): + if key.startswith("_") and key not in ["_load_connected_pipes", "_is_onnx"]: + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 +def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): + """ + Args: + Compares a library version to some requirement using a given operation. + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. 
+ requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 +def is_torch_version(operation: str, version: str): + """ + Args: + Compares the current PyTorch version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(parse(_torch_version), operation, version) + + +def is_transformers_version(operation: str, version: str): + """ + Args: + Compares the current Transformers version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _transformers_available: + return False + return compare_versions(parse(_transformers_version), operation, version) + + +def is_accelerate_version(operation: str, version: str): + """ + Args: + Compares the current Accelerate version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _accelerate_available: + return False + return compare_versions(parse(_accelerate_version), operation, version) + + +def is_peft_version(operation: str, version: str): + """ + Args: + Compares the current PEFT version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _peft_version: + return False + return compare_versions(parse(_peft_version), operation, version) + + +def is_k_diffusion_version(operation: str, version: str): + """ + Args: + Compares the current k-diffusion version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _k_diffusion_available: + return False + return compare_versions(parse(_k_diffusion_version), operation, version) + + +def get_objects_from_module(module): + """ + Args: + Returns a dict of object names and values in a module, while skipping private/internal objects + module (ModuleType): + Module to extract the objects from. + + Returns: + dict: Dictionary of object names and corresponding values + """ + + objects = {} + for name in dir(module): + if name.startswith("_"): + continue + objects[name] = getattr(module, name) + + return objects + + +class OptionalDependencyNotAvailable(BaseException): + """An error indicating that an optional dependency of Diffusers was not found in the environment.""" + + +class _LazyModule(ModuleType): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. 
+ """ + + # Very heavily inspired by optuna.integration._IntegrationModule + # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py + def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None): + super().__init__(name) + self._modules = set(import_structure.keys()) + self._class_to_module = {} + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = import_structure + + # Needed for autocompletion in an IDE + def __dir__(self): + result = super().__dir__() + # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether + # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. + for attr in self.__all__: + if attr not in result: + result.append(attr) + return result + + def __getattr__(self, name: str) -> Any: + if name in self._objects: + return self._objects[name] + if name in self._modules: + value = self._get_module(name) + elif name in self._class_to_module.keys(): + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + else: + raise AttributeError(f"module {self.__name__} has no attribute {name}") + + setattr(self, name, value) + return value + + def _get_module(self, module_name: str): + try: + return importlib.import_module("." + module_name, self.__name__) + except Exception as e: + raise RuntimeError( + f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its" + f" traceback):\n{e}" + ) from e + + def __reduce__(self): + return (self.__class__, (self._name, self.__file__, self._import_structure)) diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/utils/logging.py b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..6f93450c410c9325aa9d0cf10262b67e4c355fda --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/logging.py @@ -0,0 +1,341 @@ +# coding=utf-8 +# Copyright 2024 Optuna, Hugging Face +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Logging utilities.""" + +import logging +import os +import sys +import threading +from logging import ( + CRITICAL, # NOQA + DEBUG, # NOQA + ERROR, # NOQA + FATAL, # NOQA + INFO, # NOQA + NOTSET, # NOQA + WARN, # NOQA + WARNING, # NOQA +) +from typing import Dict, Optional + +from tqdm import auto as tqdm_lib + + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.WARNING + +_tqdm_active = True + + +def _get_default_logging_level() -> int: + """ + If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " + f"has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + + if sys.stderr: # only if sys.stderr exists, e.g. when not using pythonw in windows + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict() -> Dict[str, int]: + return log_levels + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Return a logger with the specified name. + + This function is not supposed to be directly accessed unless you are writing a custom diffusers module. + """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_verbosity() -> int: + """ + Return the current level for the 🤗 Diffusers' root logger as an `int`. + + Returns: + `int`: + Logging level integers which can be one of: + + - `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - `40`: `diffusers.logging.ERROR` + - `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - `20`: `diffusers.logging.INFO` + - `10`: `diffusers.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for the 🤗 Diffusers' root logger. 
+ + Args: + verbosity (`int`): + Logging level which can be one of: + + - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - `diffusers.logging.ERROR` + - `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - `diffusers.logging.INFO` + - `diffusers.logging.DEBUG` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def set_verbosity_info() -> None: + """Set the verbosity to the `INFO` level.""" + return set_verbosity(INFO) + + +def set_verbosity_warning() -> None: + """Set the verbosity to the `WARNING` level.""" + return set_verbosity(WARNING) + + +def set_verbosity_debug() -> None: + """Set the verbosity to the `DEBUG` level.""" + return set_verbosity(DEBUG) + + +def set_verbosity_error() -> None: + """Set the verbosity to the `ERROR` level.""" + return set_verbosity(ERROR) + + +def disable_default_handler() -> None: + """Disable the default handler of the 🤗 Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of the 🤗 Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def add_handler(handler: logging.Handler) -> None: + """adds a handler to the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + """removes given handler from the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler in _get_library_root_logger().handlers + _get_library_root_logger().removeHandler(handler) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is disabled by default. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent + double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows: + ``` + [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for 🤗 Diffusers' loggers. + + All handlers currently bound to the root logger are affected by this method. 
+ """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs) -> None: + """ + This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar() -> None: + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar() -> None: + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/utils/outputs.py b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..6080a86b871aec2218471591bccac408a45d82b2 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/outputs.py @@ -0,0 +1,137 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generic utilities +""" + +from collections import OrderedDict +from dataclasses import fields, is_dataclass +from typing import Any, Tuple + +import numpy as np + +from .import_utils import is_torch_available, is_torch_version + + +def is_tensor(x) -> bool: + """ + Tests if `x` is a `torch.Tensor` or `np.ndarray`. + """ + if is_torch_available(): + import torch + + if isinstance(x, torch.Tensor): + return True + + return isinstance(x, np.ndarray) + + +class BaseOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + Python dictionary. 
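+
+    A minimal sketch of the intended usage (the `DummyOutput` dataclass is purely illustrative and is not
+    defined in this file):
+
+    ```python
+    from dataclasses import dataclass
+    from typing import Optional
+
+    import numpy as np
+
+    @dataclass
+    class DummyOutput(BaseOutput):
+        sample: Optional[np.ndarray] = None
+        extra: Optional[int] = None
+
+    out = DummyOutput(sample=np.zeros((1, 4)))
+    out["sample"] is out.sample  # string keys and attributes point to the same value
+    out[0]                       # integer indexing goes through `to_tuple()`, which skips `None` fields
+    ```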
+ + + + You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple + first. + + + """ + + def __init_subclass__(cls) -> None: + """Register subclasses as pytree nodes. + + This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with + `static_graph=True` with modules that output `ModelOutput` subclasses. + """ + if is_torch_available(): + import torch.utils._pytree + + if is_torch_version("<", "2.2"): + torch.utils._pytree._register_pytree_node( + cls, + torch.utils._pytree._dict_flatten, + lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), + ) + else: + torch.utils._pytree.register_pytree_node( + cls, + torch.utils._pytree._dict_flatten, + lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), + ) + + def __post_init__(self) -> None: + class_fields = fields(self) + + # Safety and consistency checks + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and isinstance(first_field, dict): + for key, value in first_field.items(): + self[key] = value + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __getitem__(self, k: Any) -> Any: + if isinstance(k, str): + inner_dict = dict(self.items()) + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name: Any, value: Any) -> None: + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def __reduce__(self): + if not is_dataclass(self): + return super().__reduce__() + callable, _args, *remaining = super().__reduce__() + args = tuple(getattr(self, field.name) for field in fields(self)) + return callable, args, *remaining + + def to_tuple(self) -> Tuple[Any, ...]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. + """ + return tuple(self[k] for k in self.keys()) diff --git a/TorchJaekwon/Model/Diffusion/External/diffusers/utils/torch_utils.py b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf75b4fad4e17e416e5a3f72a25b1f2b3702fe1 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/External/diffusers/utils/torch_utils.py @@ -0,0 +1,148 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PyTorch utilities: Utilities related to PyTorch +""" + +from typing import List, Optional, Tuple, Union + +from . import logging +from .import_utils import is_torch_available, is_torch_version + + +if is_torch_available(): + import torch + from torch.fft import fftn, fftshift, ifftn, ifftshift + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +try: + from torch._dynamo import allow_in_graph as maybe_allow_in_graph +except (ImportError, ModuleNotFoundError): + + def maybe_allow_in_graph(cls): + return cls + + +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + device: Optional["torch.device"] = None, + dtype: Optional["torch.dtype"] = None, + layout: Optional["torch.layout"] = None, +): + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + logger.info( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents + + +def is_compiled_module(module) -> bool: + """Check whether the module was compiled with torch.compile()""" + if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): + return False + return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) + + +def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": + """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). 
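+
+    (In brief, as implemented below: after an FFT shift, the central 2*threshold x 2*threshold block of the
+    spectrum, i.e. the lowest frequencies, is multiplied by `scale`, while every other frequency is left
+    unchanged before transforming back.)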
+ + This version of the method comes from here: + https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706 + """ + x = x_in + B, C, H, W = x.shape + + # Non-power of 2 images must be float32 + if (W & (W - 1)) != 0 or (H & (H - 1)) != 0: + x = x.to(dtype=torch.float32) + + # FFT + x_freq = fftn(x, dim=(-2, -1)) + x_freq = fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W), device=x.device) + + crow, ccol = H // 2, W // 2 + mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = ifftshift(x_freq, dim=(-2, -1)) + x_filtered = ifftn(x_freq, dim=(-2, -1)).real + + return x_filtered.to(dtype=x_in.dtype) + + +def apply_freeu( + resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs +) -> Tuple["torch.Tensor", "torch.Tensor"]: + """Applies the FreeU mechanism as introduced in https: + //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU. + + Args: + resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied. + hidden_states (`torch.Tensor`): Inputs to the underlying block. + res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block. + s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. + s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if resolution_idx == 0: + num_half_channels = hidden_states.shape[1] // 2 + hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"] + res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"]) + if resolution_idx == 1: + num_half_channels = hidden_states.shape[1] // 2 + hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"] + res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"]) + + return hidden_states, res_hidden_states diff --git a/TorchJaekwon/Model/Diffusion/Sampler/DDIM.py b/TorchJaekwon/Model/Diffusion/Sampler/DDIM.py new file mode 100644 index 0000000000000000000000000000000000000000..66fc13f21c7f7f38f66ebc96b1df43b48ba76eed --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/Sampler/DDIM.py @@ -0,0 +1,498 @@ +from typing import Literal, Optional + +from tqdm import tqdm +import numpy as np +import torch + +from TorchJaekwon.Util.UtilTorch import UtilTorch +from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM +from TorchJaekwon.Model.Diffusion.DDPM.DiffusionUtil import DiffusionUtil + +class DDIM(object): + def __init__(self, ddpm_model:DDPM): + self.ddpm_model:DDPM = ddpm_model + self.ddpm_num_timesteps:int = ddpm_model.timesteps + self.device:torch.device = UtilTorch.get_model_device(self.ddpm_model) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != self.device: + is_mps = self.device == "mps" or self.device == torch.device("mps") + if is_mps and attr.dtype == torch.float64: + attr = attr.to(self.device, dtype=torch.float32) + else: + attr = attr.to(self.device) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", 
ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = DDIM.make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.ddpm_model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) + + self.register_buffer("betas", to_torch(self.ddpm_model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.ddpm_model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = DDIM.make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @staticmethod + def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True + ): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + @staticmethod + def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim 
sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + @torch.no_grad() + def infer( + self, + x_shape:tuple = None, + cond:Optional[dict] = None, + is_cond_unpack:bool = False, + num_steps:int = 50, + batch_size:int = None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=1.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + dynamic_threshold=None, + ucg_schedule=None, + ): + + _, cond, additional_data_dict = self.ddpm_model.preprocess(x_start = None, cond=cond) + if x_shape is None: x_shape = self.ddpm_model.get_x_shape(cond=cond) + if batch_size is not None: x_shape[0] = batch_size + + self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta, verbose=verbose) + + samples, intermediates = self.ddim_sampling( + x_shape, + cond, + is_cond_unpack, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule, + ) + return self.ddpm_model.postprocess(samples, additional_data_dict) + + @torch.no_grad() + def ddim_sampling( + self, + x_shape:tuple = None, + cond:Optional[dict] = None, + is_cond_unpack:bool = False, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ucg_schedule=None, + ): + device = self.ddpm_model.betas.device + b = x_shape[0] + if x_T is None: + img = torch.randn(x_shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.ddpm_model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? 
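+                # Note on the blend below: where mask == 1 the sample is overwritten with a noised version
+                # of the known x0 at this timestep, and where mask == 0 the current sample is kept, so only
+                # the unmasked region is actually generated (inpainting-style conditioning).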
+ img = img_orig * mask + (1.0 - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim( + img, + ts, + index=index, + cond = cond, + is_cond_unpack = is_cond_unpack, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + img, pred_x0 = outs + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim( + self, + x, + t, + index, + cond:Optional[dict] = None, + is_cond_unpack:bool = False, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): + b, *_, device = *x.shape, x.device + + model_output = self.ddpm_model.apply_model(x, t, cond, is_cond_unpack, cfg_scale = self.ddpm_model.cfg_scale) + + if self.ddpm_model.model_output_type == "v_prediction": + e_t = self.ddpm_model.predict_noise_from_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.ddpm_model.parameterization == "eps", "not implemented" + e_t = score_corrector.modify_score( + self.ddpm_model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.ddpm_model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.ddpm_model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.ddpm_model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.ddpm_model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + if self.ddpm_model.model_output_type != "v_prediction": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.ddpm_model.predict_x_start_from_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.ddpm_model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * DiffusionUtil.noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode( + self, + x0, + c, + t_enc, + use_original_steps=False, + return_intermediates=None, + unconditional_guidance_scale=1.0, + 
unconditional_conditioning=None, + callback=None, + ): + num_reference_steps = ( + self.ddpm_num_timesteps + if use_original_steps + else self.ddim_timesteps.shape[0] + ) + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc="Encoding Image"): + t = torch.full( + (x0.shape[0],), i, device=self.ddpm_model.device, dtype=torch.long + ) + if unconditional_guidance_scale == 1.0: + noise_pred = self.ddpm_model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.ddpm_model.apply_model( + torch.cat((x_next, x_next)), + torch.cat((t, t)), + torch.cat((unconditional_conditioning, c)), + ), + 2, + ) + noise_pred = e_t_uncond + unconditional_guidance_scale * ( + noise_pred - e_t_uncond + ) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = ( + alphas_next[i].sqrt() + * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) + * noise_pred + ) + x_next = xt_weighted + weighted_noise_pred + if ( + return_intermediates + and i % (num_steps // return_intermediates) == 0 + and i < num_steps - 1 + ): + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: + callback(i) + + out = {"x_encoded": x_next, "intermediate_steps": inter_steps} + if return_intermediates: + out.update({"intermediates": intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + callback=None, + ): + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + if callback: + callback(i) + return x_dec \ No newline at end of file diff --git 
a/TorchJaekwon/Model/Diffusion/Sampler/DpmSolverForDDPM.py b/TorchJaekwon/Model/Diffusion/Sampler/DpmSolverForDDPM.py new file mode 100644 index 0000000000000000000000000000000000000000..300053ddac1523f22501ecfd23d856c406d8741a --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/Sampler/DpmSolverForDDPM.py @@ -0,0 +1,45 @@ +from typing import Optional +from torch import Tensor, device + +import torch + +from TorchJaekwon.Util.UtilTorch import UtilTorch +from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM +from TorchJaekwon.Model.Diffusion.Sampler.dpm_solver_pytorch import DPM_Solver, NoiseScheduleVP, model_wrapper + +class DpmSolverForDDPM: + def __init__(self,ddpm_module:DDPM) -> None: + self.ddpm_module = ddpm_module + + @torch.no_grad() + def infer(self, + x_shape:tuple, + cond:Optional[dict] = None, + steps:int = 20, + order:Optional[int] = None, + ) -> Tensor: + model_device:device = UtilTorch.get_model_device(self.ddpm_module) + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.ddpm_module.betas) + model_fn = model_wrapper(self.get_model_wrapper_args(noise_schedule, cond)) + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + x_T:Tensor = torch.randn(x_shape, device=model_device) + if order is None: order = 3 if cond is None else 2 + x = dpm_solver.sample( + x_T, + steps=steps, + order=order, + skip_type="time_uniform", + method="multistep", + ) + return x + + def get_model_wrapper_args(self, noise_schedule:NoiseScheduleVP, cond:Optional[dict] = None) -> dict: + model_type_dict:dict = {'noise':'noise', 'x_start':'x_start', 'v_prediction':'v', 'score':'score'} + args:dict = {'model': self.ddpm_module.model, 'noise_schedule':noise_schedule, 'model_type': model_type_dict[self.ddpm_module.model_output_type], 'model_kwargs':{}} + if cond is not None: + args['model_kwargs'] = cond + if self.ddpm_module.cfg_scale is not None: + args['guidance_type'] = 'classifier-free' + args['unconditional_condition'] = self.ddpm_module.get_unconditional_condition() + args['guidance_scale'] = self.ddpm_module.cfg_scale + return args \ No newline at end of file diff --git a/TorchJaekwon/Model/Diffusion/Sampler/PNDM.py b/TorchJaekwon/Model/Diffusion/Sampler/PNDM.py new file mode 100644 index 0000000000000000000000000000000000000000..91bb7caa9070582c662b8504ac3b66efa07484a6 --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/Sampler/PNDM.py @@ -0,0 +1,72 @@ +from typing import Optional +from torch import Tensor,device + +from tqdm import tqdm +import torch +from collections import deque + +from TorchJaekwon.Util.UtilTorch import UtilTorch +from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM + +from TorchJaekwon.Model.Diffusion.DDPM.DiffusionUtil import DiffusionUtil + +class PNDM: + #Pseudo Numerical methods for Diffusion Models on manifolds (PNDM) is by Luping Liu, Yi Ren, Zhijie Lin and Zhou Zhao + + def __init__(self, ddpm_module:DDPM) -> None: + self.ddpm_module = ddpm_module + + @torch.no_grad() + def infer(self, + x_shape:Optional[tuple], + cond:Optional[dict] = None, + is_cond_unpack:bool = False, + pndm_speedup:int = 10) -> Tensor: + _, cond, additional_data_dict = self.ddpm_module.preprocess(x_start = None, cond=cond) + if x_shape is None: x_shape = self.ddpm_module.get_x_shape(cond=cond) + total_timesteps:int = self.ddpm_module.timesteps + model_device:device = UtilTorch.get_model_device(self.ddpm_module) + x:Tensor = torch.randn(x_shape, device = model_device) + self.noise_list = deque(maxlen=4) + + for i in tqdm(reversed(range(0, total_timesteps, 
pndm_speedup)), desc='sample time step', total=total_timesteps // pndm_speedup): + x = self.p_sample_plms(x, torch.full((x_shape[0],), i, device=model_device, dtype=torch.long), pndm_speedup, cond, is_cond_unpack) + + return self.ddpm_module.postprocess(x, additional_data_dict) + + @torch.no_grad() + def p_sample_plms(self, x, t, interval, cond, is_cond_unpack): + """ + Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778). + """ + + noise_list = self.noise_list + noise_pred = self.ddpm_module.apply_model(x, t, cond, is_cond_unpack, self.ddpm_module.cfg_scale) + if self.ddpm_module.model_output_type == 'v_prediction': + noise_pred = self.ddpm_module.predict_noise_from_v(x, t, noise_pred) + + if len(noise_list) == 0: + x_pred = self.get_x_pred(x, noise_pred, t, interval) + noise_pred_prev = self.ddpm_module.apply_model(x_pred, torch.max(t-interval, torch.zeros_like(t)), cond, is_cond_unpack) #max(t-interval, 0) + noise_pred_prime = (noise_pred + noise_pred_prev) / 2 + elif len(noise_list) == 1: + noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2 + elif len(noise_list) == 2: + noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12 + elif len(noise_list) >= 3: + noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24 + + x_prev = self.get_x_pred(x, noise_pred_prime, t, interval) + noise_list.append(noise_pred) + + return x_prev + + def get_x_pred(self, x, noise_t, t, interval): + a_t = DiffusionUtil.extract(self.ddpm_module.alphas_cumprod, t, x.shape) + a_prev = DiffusionUtil.extract(self.ddpm_module.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape) + a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() + + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x + x_delta + + return x_pred \ No newline at end of file diff --git a/TorchJaekwon/Model/Diffusion/Sampler/dpm_solver_pytorch.py b/TorchJaekwon/Model/Diffusion/Sampler/dpm_solver_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..34d40d5894a4995b7600e4bdb625b237638331df --- /dev/null +++ b/TorchJaekwon/Model/Diffusion/Sampler/dpm_solver_pytorch.py @@ -0,0 +1,1305 @@ +import torch +import torch.nn.functional as F +import math + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. 
For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise + schedule are the default settings in Yang Song's ScoreSDE: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.T = 1. + self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype) + self.total_N = self.log_alpha_array.shape[1] + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + else: + self.T = 1. 
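+            # Continuous-time (linear VPSDE) branch: T and total_N are fixed constants, while beta_0 and
+            # beta_1 parametrize the linear beta(t) schedule used by marginal_log_mean_coeff below.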
+ self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + + def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1): + """ + For some beta schedules such as cosine schedule, the log-SNR has numerical isssues. + We clip the log-SNR near t=T within -5.1 to ensure the stability. + Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE. + """ + log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas)) + lambs = log_alphas - log_sigmas + idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda) + if idx > 0: + log_alphas = log_alphas[:-idx] + return log_alphas + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." 
+ arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. 
/ noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim()) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -expand_dims(sigma_t, x.dim()) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. 
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) 
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. 
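+
+        Worked instance of the rule above (illustrative): with steps=20 and order=3, K = 20 // 3 + 1 = 7 and,
+        since 20 % 3 == 2, orders = [3]*6 + [2], i.e. 6*3 + 2 = 20 model evaluations in total.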
+ """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + sigma_t / sigma_s * x + - alpha_t * phi_1 * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
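+
+        (Clarifying note: with the default r1=0.5, the intermediate time `s1` is chosen so that its
+        half-logSNR lambda_s1 = lambda_s + r1 * h lies at the midpoint between lambda_s and lambda_t.)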
+ """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + torch.exp(log_alpha_s1 - log_alpha_s) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. 
+ ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (torch.exp(log_alpha_s1 - log_alpha_s)) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. 
The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + - 0.5 * (alpha_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1. / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1. 
+ phi_3 = phi_2 / h - 0.5 + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. 
+ t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. 
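+            Example (illustrative; `dpm_solver` denotes an already-constructed DPM_Solver instance):
+            >>> x0 = torch.randn(2, 8, 256)                                     # (batch_size, *shape)
+            >>> dpm_solver.add_noise(x0, torch.tensor([0.5])).shape             # torch.Size([2, 8, 256])
+            >>> dpm_solver.add_noise(x0, torch.tensor([0.1, 0.5, 0.9])).shape   # torch.Size([3, 2, 8, 256])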
+ """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type, + method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, solver_type=solver_type, + atol=atol, rtol=rtol, return_intermediate=return_intermediate) + + def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. 
+ - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. 
If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == 'adaptive': + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. 
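+            # Each iteration performs one `order`-th order update, then shifts t_prev_list and
+            # model_prev_list one slot to the left and caches the model value at the new time
+            # (the model evaluation at the final step is skipped because it is never used).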
+ for step in range(order, steps + 1): + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order,] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError("Got wrong method {}".format(method)) + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. 
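+        Example (illustrative):
+        >>> xp = torch.tensor([[0., 1., 2.]])    # keypoint x-positions, shape [C, K] with C = 1
+        >>> yp = torch.tensor([[0., 10., 20.]])  # keypoint values, shape [C, K]
+        >>> interpolate_fn(torch.tensor([[0.5], [3.0]]), xp, yp)  # -> [[5.], [30.]] (3.0 is extrapolated)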
+ """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] diff --git a/TorchJaekwon/Model/Functional.py b/TorchJaekwon/Model/Functional.py new file mode 100644 index 0000000000000000000000000000000000000000..19d692689520257934778fd9ddd67030bcfcc4f2 --- /dev/null +++ b/TorchJaekwon/Model/Functional.py @@ -0,0 +1,33 @@ +from torch import Tensor + +import torch + +class Functional: + @staticmethod + def slerp(low:Tensor, + high:Tensor, + val:float = 0.5 + ): + ''' + Spherical Linear Interpolation (Slerp) + Slerp(q_0,q_1;t) = q_0(q_0^-1 q_1)^t + = ( sin(1-t) theta ) / sin(theta) * q_0 * sin(t * theta)/sin(theta) * q_1 + where dot_product(q_0,q_1) = cos(theta) + + theta = np.arccos(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high))) + so = np.sin(theta) + return np.sin((1.0-val)*theta) / so * low + np.sin(val*theta)/so * high + ''' + assert tuple(low.shape) == tuple(high.shape), f'low shape({low.shape}) must be same as high shape({high.shape})' + feature_shape:tuple = tuple(low.shape) + # Normalize the vectors to get the directions and angles + low_1d:Tensor = low.reshape(feature_shape[0],-1) + high_1d:Tensor = high.reshape(feature_shape[0],-1) + low_norm = low_1d/torch.norm(low_1d, dim=1, keepdim=True) + high_norm = high_1d/torch.norm(high_1d, dim=1, keepdim=True) + + dot_product = (low_norm*high_norm).sum(dim = 1) + theta = torch.acos(dot_product) + so = torch.sin(theta) + res = (torch.sin((1.0-val)*theta)/so).unsqueeze(1)*low_1d + (torch.sin(val*theta)/so).unsqueeze(1) * high_1d + return res.reshape(feature_shape) \ No newline at end of file diff --git a/TorchJaekwon/Model/MultiheadAttention.py b/TorchJaekwon/Model/MultiheadAttention.py new file mode 100644 index 0000000000000000000000000000000000000000..c78151bd0dc74fb2bf140f9ea1508b5e87c43a41 --- /dev/null +++ b/TorchJaekwon/Model/MultiheadAttention.py @@ -0,0 +1,51 @@ +from typing import Optional +from torch import Tensor + +import numpy as np +import torch +import torch.nn as nn + +class MultiheadAttention(nn.Module): + def __init__(self, + query_channels:int, + key_channels:int, + value_channels:int, + 
total_hidden_channels:int, + out_channels:int, + num_heads: int = 8 + ) -> None: + super().__init__() + assert total_hidden_channels % num_heads == 0, f'hidden channel size({total_hidden_channels}) must be divisible by the number of heads({num_heads})' + self.num_heads:int = num_heads + self.hidden_channels:int = total_hidden_channels // num_heads + self.projection_query:nn.Module = nn.Conv1d(in_channels=query_channels, out_channels=total_hidden_channels, kernel_size=1) + self.projection_key:nn.Module = nn.Conv1d(in_channels=key_channels, out_channels=total_hidden_channels, kernel_size=1) + self.projection_value:nn.Module = nn.Conv1d(in_channels=value_channels, out_channels=total_hidden_channels, kernel_size=1) + self.projection_out:nn.Module = nn.Conv1d(in_channels=total_hidden_channels, out_channels=out_channels, kernel_size=1) + + def forward(self, + queries:Tensor, #torch.float32 [batch, query_channels, number_of_queries] + keys:Tensor, #torch.float32 [batch, key_channels, number_of_keys/values] + values:Tensor, #torch.float32 [batch, value_channels, number_of_keys/values] + mask: Optional[Tensor] = None #torch.float32 [batch, number_of_keys/values, number_of_queries] + ) -> Tensor: #torch.float32 [batch, out_channels, number_of_queries] + batch_size:int = queries.shape[0] + number_of_queries:int = queries.shape[-1] + number_of_keys_and_values:int = keys.shape[-1] + assert(keys.shape[-1] == values.shape[-1]), f'number of keys({keys.shape[-1]}) and number of values({values.shape[-1]}) must be the same' + #[batch, num_heads, hidden_channels, number_of_queries] + queries = self.projection_query(queries).view(batch_size, self.num_heads, -1, number_of_queries) + #[batch, num_heads, hidden_channels, number_of_keys/values] + keys = self.projection_key(keys).view(batch_size, self.num_heads, -1, number_of_keys_and_values) + values = self.projection_value(values).view(batch_size, self.num_heads, -1, number_of_keys_and_values) + #[batch, num_heads, number_of_keys/values,number_of_queries] matrix mul of [number_of_keys/values, hidden_channels], [hidden_channels,number_of_queries] + score:Tensor = torch.matmul(keys.transpose(2, 3), queries) * (self.hidden_channels ** -0.5) + if mask is not None: + score.masked_fill_(~mask[:, None, :, :1].to(torch.bool), -np.inf) + #[batch, num_heads, number_of_keys/values,number_of_queries] + weights:Tensor = torch.softmax(score, dim=2) + #[batch, out_channels, number_of_queries] + out = self.projection_out(torch.matmul(values, weights).view(batch_size, -1, number_of_queries)) + if mask is not None: + out = out * mask[:, :1] + return out \ No newline at end of file diff --git a/TorchJaekwon/Train/AverageMeter.py b/TorchJaekwon/Train/AverageMeter.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7e20f3d88e1e9079ec6d89c434f744cab625a8 --- /dev/null +++ b/TorchJaekwon/Train/AverageMeter.py @@ -0,0 +1,17 @@ +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count \ No newline at end of file diff --git a/TorchJaekwon/Train/LogWriter/LogWriter.py b/TorchJaekwon/Train/LogWriter/LogWriter.py new file mode 100644 index 0000000000000000000000000000000000000000..b9170da827836f433caf458aa12f3fc5769f3c55 --- /dev/null +++ b/TorchJaekwon/Train/LogWriter/LogWriter.py @@ -0,0 +1,185 @@
+from typing import Dict, Optional, Union +from numpy import ndarray + +import os +import psutil +import time +import torch.nn as nn +from datetime import datetime +from datetime import timedelta +try: import wandb +except: print('Didnt import following packages: wandb') +try: from tensorboardX import SummaryWriter +except: print('Didnt import following packages: tensorboardX') + +from TorchJaekwon.Util.Util import Util +from TorchJaekwon.Util.UtilAudioSTFT import UtilAudioSTFT +from TorchJaekwon.Util.UtilTorch import UtilTorch +from TorchJaekwon.Util.UtilData import UtilData + +from HParams import HParams + +class LogWriter(): + def __init__( + self, + model:nn.Module + )->None: + + self.h_params:HParams = HParams() + self.visualizer_type:str = self.h_params.log.visualizer_type #["tensorboard","wandb"] + + self.experiment_start_time:float = time.time() + self.experiment_name:str = "[" +datetime.now().strftime('%y%m%d-%H%M%S') + "] " + self.h_params.mode.config_name if self.h_params.log.use_currenttime_on_experiment_name else self.h_params.mode.config_name + + self.log_path:dict[str,str] = {"root":"","console":"","visualizer":""} + self.set_log_path() + self.log_write_init(model=model) + + if self.visualizer_type == 'wandb': + if self.h_params.mode.train == 'resume': + try: + wandb_meta_data:dict = UtilData.yaml_load(f'''{self.log_path['root']}/wandb_meta.yaml''') + wandb.init(id=wandb_meta_data['id'], project=self.h_params.log.project_name, resume = 'must') + except: + Util.print("Failed to resume wandb. Please check the wandb_meta.yaml file", type='error') + wandb.init(project=self.h_params.log.project_name) + else: + wandb.init(project=self.h_params.log.project_name) + wandb.config = {"learning_rate": self.h_params.train.lr, "epochs": self.h_params.train.epoch, "batch_size": self.h_params.pytorch_data.dataloader['train']['batch_size'] } + watched_model = model + while not isinstance(watched_model, nn.Module): + watched_model = watched_model[list(watched_model.keys())[0]] + wandb.watch(watched_model) + wandb.run.name = self.experiment_name + wandb.run.save() + + UtilData.yaml_save(f'''{self.log_path['root']}/wandb_meta.yaml''', data={ + 'id': wandb.run.id, + 'name': wandb.run.name, + }) + elif self.visualizer_type == 'tensorboard': + self.tensorboard_writer = SummaryWriter(log_dir=self.log_path["visualizer"]) + else: + print('visualizer should be either wandb or tensorboard') + exit() + + def get_time_took(self) -> str: + time_took_second:int = int(time.time() - self.experiment_start_time) + time_took:str = str(timedelta(seconds=time_took_second)) + return time_took + + def set_log_path(self): + if self.h_params.mode.train == "resume": + self.log_path["root"] = self.h_params.mode.resume_path + else: + self.log_path["root"] = os.path.join(self.h_params.log.class_root_dir,self.experiment_name) + self.log_path["console"] = self.log_path["root"]+ "/log.txt" + self.log_path["visualizer"] = os.path.join(self.log_path["root"],"tb") + + os.makedirs(self.log_path["visualizer"],exist_ok=True) + + def print_and_log(self, log_message:str) -> None: + log_message_with_time_took:str = f"{log_message} ({self.get_time_took()} took)" + print(log_message_with_time_took) + self.log_write(log_message_with_time_took) + + def log_write_init(self, model:nn.Module) -> None: + write_mode:str = 'w' if self.h_params.mode.train != "resume" else 'a' + file = open(self.log_path["console"], write_mode) + file.write("========================================="+'\n') + file.write(f'pid: {os.getpid()} / parent_pid: 
{psutil.Process(os.getpid()).ppid()} \n') + file.write("========================================="+'\n') + self.log_model_parameters(file, model) + file.write("========================================="+'\n') + file.write("Epoch :" + str(self.h_params.train.epoch)+'\n') + file.write("lr :" + str(self.h_params.train.lr)+'\n') + file.write("Batch :" + str(self.h_params.pytorch_data.dataloader['train']['batch_size'])+'\n') + file.write("========================================="+'\n') + file.close() + + def log_model_parameters(self, file, model: Union[nn.Module, dict], model_name:str = ''): + if isinstance(model, nn.Module): + file.write(f'''Model {model_name} Total parameters: {format(UtilTorch.get_param_num(model)['total'], ',d')}'''+'\n') + file.write(f'''Model {model_name} Trainable parameters: {format(UtilTorch.get_param_num(model)['trainable'], ',d')}'''+'\n') + else: + for model_name in model: + self.log_model_parameters(file, model[model_name], model_name) + + def log_write(self,log_message:str)->None: + file = open(self.log_path["console"],'a') + file.write(log_message+'\n') + file.close() + + def visualizer_log( + self, + x_axis_name:str, #epoch, step, ... + x_axis_value:float, + y_axis_name:str, #metric name + y_axis_value:float + ) -> None: + + if self.visualizer_type == 'tensorboard': + self.tensorboard_writer.add_scalar(y_axis_name,y_axis_value,x_axis_value) + else: + wandb.log({y_axis_name: y_axis_value, x_axis_name: x_axis_value}) + + def plot_audio( + self, + name:str, #test case name, you could make structure by using /. ex) 'taskcase_1/test_set_1' + audio_dict:Dict[str,ndarray], #{'audio name': 1d audio array}. + global_step:int, + sample_rate:int = 16000, + is_plot_spec:bool = False, + is_plot_mel:bool = True, + mel_spec_args:Optional[dict] = None + ) -> None: + + self.plot_wav(name = name + '_audio', audio_dict = audio_dict, sample_rate=sample_rate, global_step=global_step) + if is_plot_mel: + from TorchJaekwon.Util.UtilAudioMelSpec import UtilAudioMelSpec + if mel_spec_args is None: + mel_spec_args = UtilAudioMelSpec.get_default_mel_spec_config(sample_rate=sample_rate) + mel_spec_util = UtilAudioMelSpec(**mel_spec_args) + mel_dict = dict() + for audio_name in audio_dict: + mel_dict[audio_name] = mel_spec_util.get_hifigan_mel_spec(audio=audio_dict[audio_name],return_type='ndarray') + self.plot_spec(name = name + '_mel_spec', spec_dict = mel_dict) + + + def plot_wav( + self, + name:str, #test case name, you could make structure by using /. ex) 'audio/test_set_1' + audio_dict:Dict[str,ndarray], #{'audio name': 1d audio array}, + sample_rate:int, + global_step:int + ) -> None: + + if self.visualizer_type == 'tensorboard': + for audio_name in audio_dict: + self.tensorboard_writer.add_audio(f'{name}/{audio_name}', audio_dict[audio_name], sample_rate=sample_rate, global_step=global_step) + else: + wandb_audio_list = list() + for audio_name in audio_dict: + wandb_audio_list.append(wandb.Audio(audio_dict[audio_name], caption=audio_name,sample_rate=sample_rate)) + wandb.log({name: wandb_audio_list}) + + def plot_spec(self, + name:str, #test case name, you could make structure by using /. 
ex) 'mel/test_set_1' + spec_dict:Dict[str,ndarray], #{'name': 2d array}, + vmin=-6.0, + vmax=1.5, + transposed=False, + global_step=0): + if self.visualizer_type == 'tensorboard': + for audio_name in spec_dict: + figure = UtilAudioSTFT.spec_to_figure(spec_dict[audio_name], vmin=vmin, vmax=vmax,transposed=transposed) + self.tensorboard_writer.add_figure(f'{name}/{audio_name}',figure,global_step=global_step) + else: + wandb_mel_list = list() + for audio_name in spec_dict: + UtilAudioSTFT.spec_to_figure(spec_dict[audio_name], vmin=vmin, vmax=vmax,transposed=transposed,save_path=f'''{self.log_path['root']}/temp_img_{audio_name}.png''') + wandb_mel_list.append(wandb.Image(f'''{self.log_path['root']}/temp_img_{audio_name}.png''', caption=audio_name)) + wandb.log({name: wandb_mel_list}) + + def log_every_epoch(self,model:nn.Module): + pass \ No newline at end of file diff --git a/TorchJaekwon/Train/Loss/CrossEntropyLossWithGaussianSmoothedLabels.py b/TorchJaekwon/Train/Loss/CrossEntropyLossWithGaussianSmoothedLabels.py new file mode 100644 index 0000000000000000000000000000000000000000..48a9cfdc534b063e648adeaa6b2a5c80113f7564 --- /dev/null +++ b/TorchJaekwon/Train/Loss/CrossEntropyLossWithGaussianSmoothedLabels.py @@ -0,0 +1,113 @@ +""" +from https://github.com/dansuh17/jdcnet-pytorch +See Also: https://github.com/pytorch/pytorch/issues/7455 +""" + +import torch +from torch import nn +import math + +class CrossEntropyLossWithGaussianSmoothedLabels(nn.Module): + def __init__(self, num_classes=722, blur_range=3): + super().__init__() + self.dim = -1 + self.num_classes = num_classes + self.blur_range = blur_range + + # pre-calculate decayed values following Gaussian distribution + # up to distance of three (== blur_range) + self.gaussian_decays = [self.gaussian_val(dist=d) for d in range(blur_range + 1)] + + @staticmethod + def gaussian_val(dist: int, sigma=1): + return math.exp(-math.pow(2, dist) / (2 * math.pow(2, sigma))) + + def forward(self, pred: torch.Tensor, target: torch.Tensor): + # pred: (b, 31, 722) + # target: (b, 31) + + pred_logit = torch.log_softmax(pred, dim=self.dim) + + # out: (b, 31, 722) + target_smoothed = self.smoothed_label(target) + + # calculate the 'cross entropy' for each of 31 features + target_loss_sum = -(pred_logit * target_smoothed).sum(dim=self.dim) + return target_loss_sum.mean() # and then take their mean + + def smoothed_label(self, target: torch.Tensor): + # out: (b, 31, 722) + target_onehot = empty_onehot(target, self.num_classes).to(target.device) + + # apply gaussian smoothing + target_smoothed = self.gaussian_blur(target, target_onehot) + target_smoothed = to_onehot(target, self.num_classes, target_smoothed) + return target_smoothed + + def gaussian_blur(self, target: torch.Tensor, one_hot: torch.Tensor): + # blur the one-hot vector with gaussian decay + with torch.no_grad(): + # Going in the reverse direction from 3 -> 0 since the value on the clamped index + # will override the previous value + # when the class index is less than 4 or greater then (num_class - 4). 
+ for dist in range(self.blur_range, -1, -1): + one_hot = self.set_decayed_values(dist, target, one_hot) + return one_hot + + def set_decayed_values(self, dist: int, target_idx: torch.Tensor, one_hot: torch.Tensor): + # size of target_idx: (batch, num_seq) = (batch, 31) + # size of one_hot: (batch, num_seq, num_classes) = (batch, 31, 722) + for direction in [1, -1]: # apply at both positive / negative directions + # used `clamp` to prevent index from underflowing / overflowing + blur_idx = torch.clamp( + target_idx + (direction * dist), min=0, max=self.num_classes - 1) + # set decayed values at indices represented by blur_idx + decayed_val = self.gaussian_decays[dist] + one_hot = one_hot.scatter_( + dim=2, index=torch.unsqueeze(blur_idx, dim=2), value=decayed_val) + return one_hot + +def empty_onehot(target: torch.Tensor, num_classes: int): + # target_size = (batch, dim1, dim2, ...) + # one_hot size = (batch, dim1, dim2, ..., num_classes) + onehot_size = target.size() + (num_classes, ) + return torch.FloatTensor(*onehot_size).zero_() + + +def to_onehot(target: torch.Tensor, num_classes: int, src_onehot: torch.Tensor = None): + if src_onehot is None: + one_hot = empty_onehot(target, num_classes) + else: + one_hot = src_onehot + + last_dim = len(one_hot.size()) - 1 + + # creates a one hot vector provided the target indices + # and the Tensor that holds the one-hot vector + with torch.no_grad(): + one_hot = one_hot.scatter_( + dim=last_dim, index=torch.unsqueeze(target, dim=last_dim), value=1.0) + return one_hot + +def test_gaussian_blur(): + """ + Before blurring: + + tensor([[[0., 0., 0., 0., 1., 0., 0.], + [0., 0., 1., 0., 0., 0., 0.], + [1., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 1., 0., 0.]]]) + + After blurring: + + tensor([[[0.0000, 0.1353, 0.3679, 0.6065, 1.0000, 0.6065, 0.3679], + [0.3679, 0.6065, 1.0000, 0.6065, 0.3679, 0.1353, 0.0000], + [1.0000, 0.6065, 0.3679, 0.1353, 0.0000, 0.0000, 0.0000], + [0.0000, 0.1353, 0.3679, 0.6065, 1.0000, 0.6065, 0.3679]]]) + """ + loss = CrossEntropyLossWithGaussianSmoothedLabels(num_classes=7) + print(loss.smoothed_label(torch.LongTensor(1, 4) % 7)) + + +if __name__ == '__main__': + test_gaussian_blur() diff --git a/TorchJaekwon/Train/Loss/F0Loss.py b/TorchJaekwon/Train/Loss/F0Loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b96487ab992e714053ed6e438180d0371153c06f --- /dev/null +++ b/TorchJaekwon/Train/Loss/F0Loss.py @@ -0,0 +1,15 @@ +import torch.nn as nn +import torch + +class F0Loss(nn.Module): + def __init__(self): + super(F0Loss,self).__init__() + self.l1_loss = nn.L1Loss() + + def forward(self, pred_f0_not_pitch_dict:dict, gt_f0): + pred_pitch = (1-pred_f0_not_pitch_dict["not_pitch"]) + weighted_gt_f0 = torch.mul(gt_f0, pred_pitch) + weighted_pred_f0 = torch.mul(pred_f0_not_pitch_dict["f0"], pred_pitch) + loss = self.l1_loss(weighted_pred_f0, weighted_gt_f0) + return loss + \ No newline at end of file diff --git a/TorchJaekwon/Train/Loss/LossEnergy.py b/TorchJaekwon/Train/Loss/LossEnergy.py new file mode 100644 index 0000000000000000000000000000000000000000..443d424f7ceb7b8909953faf13d5493fa7f91a61 --- /dev/null +++ b/TorchJaekwon/Train/Loss/LossEnergy.py @@ -0,0 +1,16 @@ +import torch.nn as nn +import torch + +class LossEnergy(nn.Module): + def __init__(self) -> None: + super(LossEnergy,self).__init__() + self.l1_loss:nn.Module = nn.L1Loss() + + def forward(self, gt_mel:torch.Tensor, pred_mel:torch.Tensor) -> torch.Tensor: + square_gt_mel:torch.Tensor = gt_mel * gt_mel + square_pred_mel:torch.Tensor = 
pred_mel * pred_mel + energy_gt_mel:torch.Tensor = torch.sum(square_gt_mel,axis=2) + energy_pred_mel:torch.Tensor = torch.sum(square_pred_mel,axis=2) + loss:torch.Tensor = self.l1_loss(energy_gt_mel, energy_pred_mel) + return loss + \ No newline at end of file diff --git a/TorchJaekwon/Train/Loss/MultiScaleSpectralLoss.py b/TorchJaekwon/Train/Loss/MultiScaleSpectralLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..68c170724c8466c8fcbd37665aa4d6b8848059f6 --- /dev/null +++ b/TorchJaekwon/Train/Loss/MultiScaleSpectralLoss.py @@ -0,0 +1,25 @@ +import torch.nn as nn + +from Train.Loss.LossFunction.SingleScaleSpectralLoss import SingleScaleSpectralLoss + +class MultiScaleSpectralLoss(nn.Module): + + def __init__( + self, + n_ffts: list = [2048, 1024, 512, 256], + alpha=1.0, + overlap=0.75, + eps=1e-7): + super().__init__() + + self.losses = nn.ModuleList([SingleScaleSpectralLoss(n_fft, alpha, overlap, eps) for n_fft in n_ffts]) + + + def forward(self, x_pred, x_true): + + # cut reverbation off + x_pred = x_pred[..., : x_true.shape[-1]] + + losses = [loss(x_pred, x_true) for loss in self.losses] + + return sum(losses).sum() \ No newline at end of file diff --git a/TorchJaekwon/Train/Loss/SingleScaleSpectralLoss.py b/TorchJaekwon/Train/Loss/SingleScaleSpectralLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..7005c562c419824d5a6e850ca0d11337cb2f8044 --- /dev/null +++ b/TorchJaekwon/Train/Loss/SingleScaleSpectralLoss.py @@ -0,0 +1,25 @@ +import torch.nn as nn +from torchaudio.transforms import Spectrogram +import torch.nn.functional as F + +class SingleScaleSpectralLoss(nn.Module): + def __init__(self, n_fft, alpha=1.0, overlap=0.75, eps=1e-7): + super(SingleScaleSpectralLoss,self).__init__() + self.n_fft = n_fft + self.alpha = alpha + self.eps = eps + self.hop_length = int(n_fft * (1 - overlap)) # 25% of the length + self.spec = Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length) + + def forward(self, x_pred, x_true): + #spec = Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length) + #spec.to(x_pred.device) + + S_true = self.spec(x_true) + S_pred = self.spec(x_pred) + + linear_term = F.l1_loss(S_pred, S_true) + log_term = F.l1_loss((S_true + self.eps).log2(), (S_pred + self.eps).log2()) + + loss = linear_term + self.alpha * log_term + return loss \ No newline at end of file diff --git a/TorchJaekwon/Train/Loss/SpectralLoss.py b/TorchJaekwon/Train/Loss/SpectralLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..666f8c91df8e40ae3e8303212263db7fcd00b3cc --- /dev/null +++ b/TorchJaekwon/Train/Loss/SpectralLoss.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn +from HParams import HParams + +class SpectralLoss(nn.Module): + def __init__(self): + super(SpectralLoss,self).__init__() + self.l1_loss = nn.L1Loss() + self.weight = torch.linspace(1,0.7,60).unsqueeze(-1) + + def forward(self,gt_spectral,pred_spectral): + if self.weight.device != gt_spectral.device: + self.weight = self.weight.to(gt_spectral.device) + weighted_gt_spectral = torch.mul(gt_spectral, self.weight) + weighted_pred_spectral = torch.mul(pred_spectral, self.weight) + loss = self.l1_loss(weighted_pred_spectral, weighted_gt_spectral) + return loss \ No newline at end of file diff --git a/TorchJaekwon/Train/Optimizer/OptimizerControl.py b/TorchJaekwon/Train/Optimizer/OptimizerControl.py new file mode 100644 index 0000000000000000000000000000000000000000..235cfea6edb9e3c40ba2b8cb5059f24301c03d81 --- /dev/null +++ 
b/TorchJaekwon/Train/Optimizer/OptimizerControl.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn + +from HParams import HParams + +class OptimizerControl: + def __init__(self, model:nn.Module = None) -> None: + self.h_params = HParams() + self.optimizer = None + self.lr_scheduler = None + + self.scheduler_config = None + self.num_lr_scheduler_step = 0 + + if model is not None: + self.set_optimizer(model) + self.set_lr_scheduler() + + def set_optimizer(self,model:nn.Module): + optimizer_name:str = self.h_params.train.optimizer["name"] + + optimizer_config:dict = self.h_params.train.optimizer["config"] + optimizer_config["params"] = filter(lambda p: p.requires_grad, model.parameters()) + + for float_parameter in ['lr','eps']: + if float_parameter in optimizer_config: + optimizer_config[float_parameter] = float(optimizer_config[float_parameter]) + + optimizer_class = getattr(torch.optim,optimizer_name,None) + if optimizer_class is not None: + self.optimizer = optimizer_class(**optimizer_config) + + def optimizer_step(self): + self.optimizer.step() + + def optimizer_zero_grad(self): + self.optimizer.zero_grad() + + def optimizer_state_dict(self): + return self.optimizer.state_dict() + + def optimizer_load_state_dict(self, state_dict): + self.optimizer.load_state_dict(state_dict) + + def lr_scheduler_state_dict(self): + if self.lr_scheduler is not None: + return self.lr_scheduler.state_dict() + + def lr_scheduler_load_state_dict(self, state_dict): + if self.lr_scheduler is not None: + self.lr_scheduler.load_state_dict(state_dict) + + def set_lr_scheduler(self): + scheduler_dict:dict= getattr(self.h_params.train,'scheduler',None) + if scheduler_dict is not None: + self.scheduler_config = scheduler_dict + scheduler_parameter_dict = scheduler_dict['config'] + scheduler_parameter_dict['optimizer'] = self.optimizer + scheduler_class = getattr(torch.optim.lr_scheduler,scheduler_dict['name'],None) + self.lr_scheduler = scheduler_class(**scheduler_parameter_dict) + + def lr_scheduler_step(self,interval_type="step",args = None): + if self.lr_scheduler == None or (self.num_lr_scheduler_step % self.scheduler_config["frequency"]) != 0 or interval_type != self.scheduler_config["interval"]: + return + + self.lr_scheduler.step() + self.num_lr_scheduler_step += 1 + + def get_lr(self): + return self.optimizer.param_groups[0]["lr"] \ No newline at end of file diff --git a/TorchJaekwon/Train/Optimizer/OptimizerControlGan.py b/TorchJaekwon/Train/Optimizer/OptimizerControlGan.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f9631e6c5d17b0e0ca70cf29f75e9c050232d5 --- /dev/null +++ b/TorchJaekwon/Train/Optimizer/OptimizerControlGan.py @@ -0,0 +1,168 @@ +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + +from Model.ModelGan import ModelGan +from HParams import HParams + +class OptimizerControlGan: + def __init__(self, model:ModelGan = None) -> None: + self.h_params = HParams() + + self.generator_optimizer:Optimizer = None + self.discriminator_optimizer:Optimizer = None + + self.generator_lr_scheduler:_LRScheduler = None + self.discriminator_lr_scheduler:_LRScheduler = None + + self.scheduler_config:dict = self.h_params.train.scheduler + self.num_gen_lr_scheduler_step:int = 0 + self.num_dis_lr_scheduler_step:int = 0 + + if model is not None: + self.set_optimizer(model) + self.set_lr_scheduler() + + def set_optimizer(self,model:ModelGan) -> None: + generator_optimizer_name:str = self.h_params.train.optimizer["generator_name"] + + 
generator_optimizer_config:dict = self.h_params.train.optimizer["generator_config"] + generator_optimizer_config["params"] = model.generator.parameters() + generator_optimizer_config['lr'] = float(generator_optimizer_config['lr']) + generator_optimizer_config['eps'] = float(generator_optimizer_config['eps']) + + self.generator_optimizer = self.get_optimizer(generator_optimizer_name,generator_optimizer_config) + + discriminator_optimizer_name:str = self.h_params.train.optimizer["generator_name"] + + discriminator_optimizer_config:dict = self.h_params.train.optimizer["discriminator_config"] + discriminator_optimizer_config["params"] = model.discriminator.parameters() + discriminator_optimizer_config['lr'] = float(discriminator_optimizer_config['lr']) + discriminator_optimizer_config['eps'] = float(discriminator_optimizer_config['eps']) + + self.discriminator_optimizer = self.get_optimizer(discriminator_optimizer_name,discriminator_optimizer_config) + + def get_optimizer(self,optimizer_name:str, optimizer_config_dict:dict) -> Optimizer: + if optimizer_name == "Adam": + return torch.optim.Adam(**optimizer_config_dict) + + def optimizer_state_dict(self) -> dict: + return {"generator": self.generator_optimizer.state_dict(),"discriminator": self.discriminator_optimizer.state_dict()} + + def optimizer_load_state_dict(self, state_dict) -> None: + self.generator_optimizer.load_state_dict(state_dict['generator']) + self.discriminator_optimizer.load_state_dict(state_dict['discriminator']) + + + def lr_scheduler_state_dict(self) -> dict: + state_dict_of_lr_scheduler:dict = dict() + + if self.generator_lr_scheduler is not None: + state_dict_of_lr_scheduler['generator'] = self.generator_lr_scheduler.state_dict() + + if self.discriminator_lr_scheduler is not None: + state_dict_of_lr_scheduler['discriminator'] = self.discriminator_lr_scheduler.state_dict() + + return state_dict_of_lr_scheduler + + def lr_scheduler_load_state_dict(self, state_dict:dict) -> None: + if self.generator_lr_scheduler is not None: + self.generator_lr_scheduler.load_state_dict(state_dict['generator']) + + if self.discriminator_lr_scheduler is not None: + self.discriminator_lr_scheduler.load_state_dict(state_dict['discriminator']) + + def set_lr_scheduler(self) -> None: + pass + + def lr_scheduler_step(self,interval_type:str = "step",args = None) -> None: + self.gen_lr_scheduler_step(interval_type,args) + self.disc_lr_scheduler_step(interval_type,args) + + def gen_lr_scheduler_step(self,interval_type:str = "step",args = None) -> None: + if ((self.num_gen_lr_scheduler_step) % self.scheduler_config["generator_config"]["frequency"]) != 0: + return + if interval_type != self.scheduler_config["generator_config"]["interval"]: + return + + if self.generator_lr_scheduler is not None: + self.generator_lr_scheduler.step() + + self.num_gen_lr_scheduler_step += 1 + + def disc_lr_scheduler_step(self,interval_type:str = "step",args = None) -> None: + if ((self.num_dis_lr_scheduler_step) % self.scheduler_config["discriminator_config"]["frequency"]) != 0: + return + if interval_type != self.scheduler_config["discriminator_config"]["interval"]: + return + + if self.discriminator_lr_scheduler is not None: + self.discriminator_lr_scheduler.step() + + self.num_dis_lr_scheduler_step += 1 + + def get_lr(self) -> float: + return self.generator_optimizer.param_groups[0]["lr"] + +''' + def __init__(self,model:MelGan,h_params:HParams) -> None: + self.h_params = h_params + self.discriminator_optimizer = torch.optim.Adam( + model.discriminator.parameters(), + 
lr=h_params.train.lr, + weight_decay=h_params.train.weight_decay + ) + self.generator_optimizer = torch.optim.Adam( + model.generator.parameters(), + lr=h_params.train.lr, + weight_decay=h_params.train.weight_decay + ) + + self.discriminator_state_name = "discriminator" + self.generator_state_name = "generator" + self.current_state = self.generator_state_name + + self.lr_scheduler_discriminator = self.get_lr_scheduler(self.discriminator_optimizer) + + self.lr_scheduler_generator = self.get_lr_scheduler(self.generator_optimizer) + + def get_lr_scheduler(self,optimizer): + if self.h_params.train.optimizer_name == "ReduceLROnPlateau": + return torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + factor=self.h_params.train.lr_decay_gamma, + patience=self.h_params.train.lr_decay_patience, + cooldown=10, + ) + elif self.h_params.train.optimizer_name == "StepLR": + return torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.h_params.train.lr_scheduler_step_size, gamma=self.h_params.train.lr_decay_factor) + + def zero_grad(self): + self.discriminator_optimizer.zero_grad() + self.generator_optimizer.zero_grad() + + def step(self): + if self.current_state == self.discriminator_state_name: + self.discriminator_optimizer.step() + elif self.current_state == self.generator_state_name: + self.generator_optimizer.step() + + def state_dict(self): + return {"generator": self.generator_optimizer.state_dict(),"discriminator": self.discriminator_optimizer.state_dict()} + + def load_state_dict(self, state_dict_dict): + self.generator_optimizer.load_state_dict(state_dict_dict["generator"]) + self.discriminator_optimizer.load_state_dict(state_dict_dict["discriminator"]) + + def lr_scheduler_step(self,vaild_loss=None): + if self.h_params.train.optimizer_name == "ReduceLROnPlateau": + if self.current_state == self.discriminator_state_name: + self.lr_scheduler_discriminator.step(vaild_loss) + elif self.current_state == self.generator_state_name: + self.lr_scheduler_generator.step(vaild_loss) + elif self.h_params.train.optimizer_name == "StepLR": + if self.current_state == self.discriminator_state_name: + self.lr_scheduler_discriminator.step() + elif self.current_state == self.generator_state_name: + self.lr_scheduler_generator.step() +''' diff --git a/TorchJaekwon/Train/Optimizer/OptimizerControlGanLambdaLRByStep.py b/TorchJaekwon/Train/Optimizer/OptimizerControlGanLambdaLRByStep.py new file mode 100644 index 0000000000000000000000000000000000000000..156c91c7715408c6ab74bfaddbe75947c2dc743d --- /dev/null +++ b/TorchJaekwon/Train/Optimizer/OptimizerControlGanLambdaLRByStep.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn +from functools import partial + +from HParams import HParams + +from Train.Optimizer.OptimizerControlGan import OptimizerControlGan +from torch.optim.lr_scheduler import LambdaLR + +class OptimizerControlGanLambdaLRByStep(OptimizerControlGan): + + def set_lr_scheduler(self) -> None: + scheduler_config:dict = self.h_params.train.scheduler["generator_config"] + self.generator_lr_scheduler = LambdaLR( + self.generator_optimizer, + partial( get_lr_lambda, warm_up_steps=scheduler_config["warm_up_steps"], reduce_lr_steps=scheduler_config["reduce_lr_steps"]) + ) + + scheduler_config = self.h_params.train.scheduler["discriminator_config"] + self.discriminator_lr_scheduler = LambdaLR( + self.generator_optimizer, + partial( get_lr_lambda, warm_up_steps=scheduler_config["warm_up_steps"], reduce_lr_steps=scheduler_config["reduce_lr_steps"]) + ) + +def get_lr_lambda(step, warm_up_steps: int, 
reduce_lr_steps: int): + r"""Get lr_lambda for LambdaLR. E.g., + .. code-block: python + lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000) + from torch.optim.lr_scheduler import LambdaLR + LambdaLR(optimizer, lr_lambda) + Args: + warm_up_steps: int, steps for warm up + reduce_lr_steps: int, reduce learning rate by 0.9 every #reduce_lr_steps steps + Returns: + learning rate: float + """ + if step <= warm_up_steps: + return step / warm_up_steps + else: + return 0.9 ** (step // reduce_lr_steps) \ No newline at end of file diff --git a/TorchJaekwon/Train/Optimizer/OptimizerControlLambdaLRByStep.py b/TorchJaekwon/Train/Optimizer/OptimizerControlLambdaLRByStep.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa2254015a3ab4554f0a1a5f4fe50f21a121241 --- /dev/null +++ b/TorchJaekwon/Train/Optimizer/OptimizerControlLambdaLRByStep.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +from functools import partial + +from HParams import HParams + +from TorchJAEKWON.Train.Optimizer.OptimizerControl import OptimizerControl +from torch.optim.lr_scheduler import LambdaLR + +class OptimizerControlLambdaLRByStep(OptimizerControl): + def __init__(self,model:nn.Module = None) -> None: + super().__init__(model) + + def set_lr_scheduler(self): + self.scheduler_config = self.h_params.train.scheduler["config"] + self.lr_scheduler = LambdaLR( + self.optimizer, + partial( get_lr_lambda, warm_up_steps=self.scheduler_config["warm_up_steps"], reduce_lr_steps=self.scheduler_config["reduce_lr_steps"]) + ) + +def get_lr_lambda(step, warm_up_steps: int, reduce_lr_steps: int): + r"""Get lr_lambda for LambdaLR. E.g., + .. code-block: python + lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000) + from torch.optim.lr_scheduler import LambdaLR + LambdaLR(optimizer, lr_lambda) + Args: + warm_up_steps: int, steps for warm up + reduce_lr_steps: int, reduce learning rate by 0.9 every #reduce_lr_steps steps + Returns: + learning rate: float + """ + if step <= warm_up_steps: + return step / warm_up_steps + else: + return 0.9 ** (step // reduce_lr_steps) \ No newline at end of file diff --git a/TorchJaekwon/Train/Trainer/GANTrainer.py b/TorchJaekwon/Train/Trainer/GANTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0e35acc7c121179ff2fbd039c10c6804d5fb16 --- /dev/null +++ b/TorchJaekwon/Train/Trainer/GANTrainer.py @@ -0,0 +1,28 @@ + +from typing import Union, Literal + +import os +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from TorchJaekwon.Util.UtilData import UtilData +from TorchJaekwon.Train.Trainer.Trainer import Trainer, TrainState + +from TorchJaekwon.Train.AverageMeter import AverageMeter + +class GANTrainer(Trainer): + def __init__(self, + model_class_name:Union[str, list], # { 'generator': ['generatorname'] , 'discriminator': ['discriminatorname'] } + optimizer_class_meta_dict:dict, # { 'generator': {'name': 'AdamW', args: {'lr':1.0e-3}, model_name_list: ['gemeratorname'] } , 'discriminator': ... 
} + discriminator_freeze_step:int = 0, + **kwargs): + super().__init__(model_class_name=model_class_name, optimizer_class_meta_dict = optimizer_class_meta_dict, **kwargs) + self.discriminator_freeze_step:int = discriminator_freeze_step + + def backprop(self,loss): + pass + + def lr_scheduler_step(self, call_state:Literal['step','epoch'], args = None): + pass + \ No newline at end of file diff --git a/TorchJaekwon/Train/Trainer/Parallel.py b/TorchJaekwon/Train/Trainer/Parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..c3bbcf08604265015cb526adb1506c6ad8f2f235 --- /dev/null +++ b/TorchJaekwon/Train/Trainer/Parallel.py @@ -0,0 +1,190 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## ECE Department, Rutgers University +## Email: zhang.hang@rutgers.edu +## Copyright (c) 2017 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +"""Encoding Data Parallel""" +import threading +import functools +import torch +from torch.autograd import Variable, Function +import torch.cuda.comm as comm +from torch.nn.parallel.data_parallel import DataParallel +from torch.nn.parallel.parallel_apply import get_a_var +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +torch_ver = torch.__version__[:3] + +__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', + 'patch_replication_callback'] + +def allreduce(*inputs): + """Cross GPU all reduce autograd operation for calculate mean and + variance in SyncBN. + """ + return AllReduce.apply(*inputs) + +class AllReduce(Function): + @staticmethod + def forward(ctx, num_inputs, *inputs): + ctx.num_inputs = num_inputs + ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] + inputs = [inputs[i:i + num_inputs] + for i in range(0, len(inputs), num_inputs)] + # sort before reduce sum + inputs = sorted(inputs, key=lambda i: i[0].get_device()) + results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) + outputs = comm.broadcast_coalesced(results, ctx.target_gpus) + return tuple([t for tensors in outputs for t in tensors]) + + @staticmethod + def backward(ctx, *inputs): + inputs = [i.data for i in inputs] + inputs = [inputs[i:i + ctx.num_inputs] + for i in range(0, len(inputs), ctx.num_inputs)] + results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) + outputs = comm.broadcast_coalesced(results, ctx.target_gpus) + return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) + +class Reduce(Function): + @staticmethod + def forward(ctx, *inputs): + ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] + inputs = sorted(inputs, key=lambda i: i.get_device()) + return comm.reduce_add(inputs) + + @staticmethod + def backward(ctx, gradOutput): + return Broadcast.apply(ctx.target_gpus, gradOutput) + + +class DataParallelModel(DataParallel): + """Implements data parallelism at the module level. + + This container parallelizes the application of the given module by + splitting the input across the specified devices by chunking in the + batch dimension. + In the forward pass, the module is replicated on each device, + and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module. 
+ Note that the outputs are not gathered, please use compatible + :class:`encoding.parallel.DataParallelCriterion`. + + The batch size should be larger than the number of GPUs used. It should + also be an integer multiple of the number of GPUs so that each chunk is + the same size (so that each GPU processes the same number of samples). + + Args: + module: module to be parallelized + device_ids: CUDA devices (default: all devices) + + Reference: + Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, + Amit Agrawal. “Context Encoding for Semantic Segmentation. + *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* + + Example:: + + >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) + >>> y = net(x) + """ + def gather(self, outputs, output_device): + return outputs + + def replicate(self, module, device_ids): + modules = super(DataParallelModel, self).replicate(module, device_ids) + return modules + + +class DataParallelCriterion(DataParallel): + """ + Calculate loss in multiple-GPUs, which balance the memory usage for + Semantic Segmentation. + + The targets are splitted across the specified devices by chunking in + the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. + + Reference: + Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, + Amit Agrawal. “Context Encoding for Semantic Segmentation. + *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* + + Example:: + + >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) + >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2]) + >>> y = net(x) + >>> loss = criterion(y, target) + """ + def forward(self, inputs, *targets, **kwargs): + # input should be already scatterd + # scattering the targets instead + if not self.device_ids: + return self.module(inputs, *targets, **kwargs) + targets, kwargs = self.scatter(targets, kwargs, self.device_ids) + if len(self.device_ids) == 1: + return self.module(inputs, *targets[0], **kwargs[0]) + replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) + outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs) + return Reduce.apply(*outputs) / len(outputs) + + +def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): + assert len(modules) == len(inputs) + assert len(targets) == len(inputs) + if kwargs_tup: + assert len(modules) == len(kwargs_tup) + else: + kwargs_tup = ({},) * len(modules) + if devices is not None: + assert len(modules) == len(devices) + else: + devices = [None] * len(modules) + + lock = threading.Lock() + results = {} + if torch_ver != "0.3": + grad_enabled = torch.is_grad_enabled() + + def _worker(i, module, input, target, kwargs, device=None): + if type(target) is tuple: + target = target[0] + + if torch_ver != "0.3": + torch.set_grad_enabled(grad_enabled) + if device is None: + device = get_a_var(input).get_device() + try: + with torch.cuda.device(device): + output = module(*(input + target), **kwargs) + with lock: + results[i] = output + except Exception as e: + with lock: + results[i] = e + + if len(modules) > 1: + threads = [threading.Thread(target=_worker, + args=(i, module, input, target, + kwargs, device),) + for i, (module, input, target, kwargs, device) in + enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + else: + 
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) + + outputs = [] + for i in range(len(inputs)): + output = results[i] + if isinstance(output, Exception): + raise output + outputs.append(output) + return outputs \ No newline at end of file diff --git a/TorchJaekwon/Train/Trainer/TemplateTrainer.py b/TorchJaekwon/Train/Trainer/TemplateTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc22efaaf4112a6b5da29ee8241e173f84cf377 --- /dev/null +++ b/TorchJaekwon/Train/Trainer/TemplateTrainer.py @@ -0,0 +1,37 @@ +#type +from torch import Tensor +#import +import numpy as np +import torch +import torch.nn as nn +#torchjaekwon import +from TorchJaekwon.Train.Trainer.Trainer import Trainer, TrainState +from TorchJaekwon.Train.AverageMeter import AverageMeter +#internal import + +class TemplateTrainer(Trainer): + + def __init__(self): + super().__init__() + + def run_step(self,data,metric,train_state): + """ + run 1 step + 1. get data + 2. use model + 3. calculate loss + 4. put the loss in metric (append) + return loss,metric + """ + data_dict = self.data_dict_to_device(data) + pred_hr_audio:Tensor = self.model['generator'][self.generator_name](audio1 = data_dict['hr_audio'], audio2 = data_dict['lr_audio']) + current_loss_dict = dict() + current_loss_dict['disc_total'], current_loss_dict['disc_mrd'], current_loss_dict['disc_mpd'] = self.discriminator_step(data_dict,pred_hr_audio['hr_audio'],train_state) + current_loss_dict['gen_total'], current_loss_dict['gen_mel'], current_loss_dict['gen_mpd'], current_loss_dict['gen_mrd'], current_loss_dict['gen_mpd_fm'], current_loss_dict['gen_mrd_fm'] = self.generator_step(data_dict,pred_hr_audio['hr_audio'],train_state) + + for loss_name in current_loss_dict: self.metric_update(metric,loss_name,current_loss_dict[loss_name],batch_size) + + return current_loss_dict["gen_total"],metric + + def log_media(self) -> None: + pass \ No newline at end of file diff --git a/TorchJaekwon/Train/Trainer/Trainer.py b/TorchJaekwon/Train/Trainer/Trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..72a2110ddde6a6957653de01766aba126a363884 --- /dev/null +++ b/TorchJaekwon/Train/Trainer/Trainer.py @@ -0,0 +1,541 @@ +#type +from typing import Dict, Union, Literal, Type +from enum import Enum,unique +#import +import os +import random +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +#torchjaekwon import +from TorchJaekwon.GetModule import GetModule +from TorchJaekwon.Data.PytorchDataLoader.PytorchDataLoader import PytorchDataLoader +from TorchJaekwon.Train.LogWriter.LogWriter import LogWriter +from TorchJaekwon.Util.UtilData import UtilData +from TorchJaekwon.Util.Util import Util +from TorchJaekwon.Train.AverageMeter import AverageMeter +#internal import +from HParams import HParams + +@unique +class TrainState(Enum): + TRAIN = "train" + VALIDATE = "valid" + TEST = "test" + +class Trainer(): + def __init__(self, + #resource + device:torch.device, + #class_meta + data_class_meta_dict:dict, + model_class_name:Union[str, list], + model_class_meta_dict:dict, + optimizer_class_meta_dict:dict, # meta_dict or {key_name: meta_dict} / meta_dict: {'name': 'Adam', 'args': {'lr': 0.0001}, model_name_list: []} + lr_scheduler_class_meta_dict:dict, + loss_class_meta:dict, + #train params + max_norm_value_for_gradient_clip:float, + #train setting + total_epoch:int, + total_step:int, + save_model_every_step:int, + do_log_every_epoch:bool, + seed: float, + seed_strict:bool, + 
debug_mode:bool = False, + use_torch_compile:bool = True + ) -> None: + self.h_params = HParams() + self.device:torch.device = device + + self.data_class_meta_dict:dict = data_class_meta_dict + + self.model_class_name:Union[str, list] = model_class_name + self.model_class_meta_dict:dict = model_class_meta_dict + self.model:Union[nn.Module, list, dict] = None + + self.optimizer_class_meta_dict:dict = optimizer_class_meta_dict + self.optimizer:torch.optim.Optimizer = None + self.lr_scheduler_class_meta_dict:dict = lr_scheduler_class_meta_dict + self.lr_scheduler:torch.optim.lr_scheduler = None + + self.loss_function_dict:dict = dict() + self.loss_class_meta:dict = loss_class_meta + + self.data_loader_dict:dict = {subset: None for subset in ['train','valid','test']} + + self.seed:int = seed + self.set_seeds(self.seed, seed_strict) + + self.max_norm_value_for_gradient_clip:float = max_norm_value_for_gradient_clip + + self.current_epoch:int = 1 + self.total_epoch:int = total_epoch + self.total_step:int = total_step + self.global_step:int = 0 + self.local_step:int = 0 + self.best_valid_metric:dict[str,AverageMeter] = None + self.best_valid_epoch:int = 0 + self.save_model_every_step:int = save_model_every_step + self.do_log_every_epoch:bool = do_log_every_epoch + + self.debug_mode = debug_mode + self.use_torch_compile = use_torch_compile + if debug_mode: + Util.print("debug mode is on", type='warning') + torch.autograd.set_detect_anomaly(True) + else: + Util.print("debug mode is off. \n - [off] torch.autograd.set_detect_anomaly", type='info') + if self.use_torch_compile: + Util.print("\n - [on] torch.compile", type='info') + else: + Util.print("\n - [off] torch.compile", type='warning') + + + ''' + ============================================================== + abstract method start + ============================================================== + ''' + + def run_step(self,data,metric,train_state:TrainState): + """ + run 1 step + 1. get data + 2. use model + + 3. calculate loss + current_loss_dict = self.loss_control.calculate_total_loss_by_loss_meta_dict(pred_dict=pred, target_dict=train_data_dict) + + 4. 
put the loss in metric (append) + for loss_name in current_loss_dict: + metric[loss_name].update(current_loss_dict[loss_name].item(),batch_size) + + return current_loss_dict["total_loss"],metric + """ + raise NotImplementedError + + def save_best_model(self,prev_best_metric, current_metric): + return None + + def update_metric(self, metric:Dict[str,AverageMeter], loss_name:str, loss:torch.Tensor, batch_size:int) -> dict: + if loss_name not in metric: + metric[loss_name] = AverageMeter() + metric[loss_name].update(loss.item(), batch_size) + return metric + + + def log_metric( + self, + metrics:Dict[str,AverageMeter], + data_size: int, + train_state=TrainState.TRAIN + )->None: + """ + log and visualizer log + """ + if train_state == TrainState.TRAIN: + x_axis_name:str = "step_global" + x_axis_value:int = self.global_step + else: + x_axis_name:str = "epoch" + x_axis_value:int = self.current_epoch + + log:str = f'Epoch ({train_state.value}): {self.current_epoch:03} ({self.local_step}/{data_size}) global_step: {self.global_step} lr: {self.get_current_lr(self.optimizer)}\n' + + for metric_name in metrics: + val:float = metrics[metric_name].avg + log += f' {metric_name}: {val:.06f}' + self.log_writer.visualizer_log( + x_axis_name=x_axis_name, + x_axis_value=x_axis_value, + y_axis_name=f'{train_state.value}/{metric_name}', + y_axis_value=val + ) + self.log_writer.print_and_log(log) + + @torch.no_grad() + def log_media(self) -> None: + pass + + ''' + ============================================================== + abstract method end + ============================================================== + ''' + def set_seeds(self, seed:float, strict=False) -> None: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if strict: + torch.backends.cudnn.deterministic = True + np.random.seed(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + + def init_train(self, dataset_dict=None): + self.model = self.init_model(self.model_class_name) + self.optimizer = self.init_optimizer(self.optimizer_class_meta_dict) + if self.lr_scheduler_class_meta_dict is not None: + self.lr_scheduler = self.init_lr_scheduler(self.optimizer, self.lr_scheduler_class_meta_dict) + self.init_loss() + self.model_to_device(self.model) + + self.log_writer:LogWriter = LogWriter(model=self.model) + self.set_data_loader(dataset_dict) + + + def init_model(self, model_class_name:Union[str, list, dict]) -> None: + if isinstance(model_class_name, list): + model = dict() + for name in model_class_name: + model[name] = self.init_model(name) + elif isinstance(model_class_name, dict): + model = dict() + for name in model_class_name: + model[name] = self.init_model(model_class_name[name]) + else: + model:nn.Module = GetModule.get_model(model_class_name) + if not self.debug_mode and self.use_torch_compile: + model = torch.compile(model) + return model + + def init_optimizer(self, optimizer_class_meta_dict:dict) -> None: + optimizer_class_name = optimizer_class_meta_dict.get('name',None) + if optimizer_class_name is None: + optimizer = dict() + for key in optimizer_class_meta_dict: + optimizer[key] = self.init_optimizer(optimizer_class_meta_dict[key]) + else: + optimizer_class = getattr(torch.optim, optimizer_class_name) + model_name_list:list = optimizer_class_meta_dict.get('model_name_list', None) + if model_name_list is None: + params = self.model.parameters() + else: + params = self.get_params(self.model, model_name_list) + + optimizer_args:dict = {"params": 
params} + optimizer_args.update(optimizer_class_meta_dict['args']) + optimizer_args['lr'] = float(optimizer_args['lr']) + optimizer = optimizer_class(**optimizer_args) + return optimizer + + def get_params(self, + model:dict, + model_name_list:list + ) -> dict: + params = list() + for model_name in model: + if isinstance(model[model_name], nn.Module): + if model_name in model_name_list: + params += list(model[model_name].parameters()) + else: + #model[model_name] is dict + params += self.get_params(model[model_name], model_name_list) + return params + + + + def init_lr_scheduler(self, optimizer, lr_scheduler_class_meta_dict) -> None: + if isinstance(optimizer, dict): + lr_scheduler = dict() + for key in optimizer: + lr_scheduler[key] = self.init_lr_scheduler(optimizer[key], self.lr_scheduler_class_meta_dict[key]) + else: + lr_scheduler_name:str = lr_scheduler_class_meta_dict.get('name',None) + lr_scheduler_class = getattr(torch.optim.lr_scheduler, lr_scheduler_name) + lr_scheduler_args:dict = lr_scheduler_class_meta_dict['args'] + lr_scheduler_args.update({'optimizer': optimizer}) + lr_scheduler = lr_scheduler_class(**lr_scheduler_args) + return lr_scheduler + + def init_loss(self) -> None: + for loss_name in self.loss_class_meta: + loss_class: Type[torch.nn.Module] = getattr(torch.nn, self.loss_class_meta[loss_name]['class_meta']['name']) # loss_name:Literal['L1Loss'] + self.loss_function_dict[loss_name] = loss_class() + + def model_to_device(self, model:Union[nn.Module, dict], device = None) -> None: + if isinstance(model, dict): + for model_name in model: + self.model_to_device(model[model_name]) + else: + if device is None: + model = model.to(self.device) + else: + model = model.to(device) + ''' + if self.h_params.resource.multi_gpu: + from TorchJaekwon.Train.Trainer.Parallel import DataParallelModel, DataParallelCriterion + self.model = DataParallelModel(self.model) + self.model.cuda() + for loss_name in self.loss_control.loss_function_dict: + self.loss_control.loss_function_dict[loss_name] = DataParallelCriterion(self.loss_control.loss_function_dict[loss_name]) + else: + for loss_name in self.loss_function_dict: + self.loss_function_dict[loss_name] = self.loss_function_dict[loss_name].to(self.device) + if isinstance(self.model_class_name, list): + for class_name in self.model_class_name: + self.model[class_name] = self.model[class_name].to(self.device) + elif isinstance(self.model_class_name, dict): + for type_name in self.model_class_name: + for class_name in self.model_class_name[type_name]: + self.model[type_name][class_name] = self.model[type_name][class_name].to(self.device) + else: + self.model = self.model.to(self.device) + ''' + + def data_dict_to_device(self,data_dict:dict) -> dict: + for feature_name in data_dict: + if isinstance(data_dict[feature_name],dict): + data_dict[feature_name] = self.data_dict_to_device(data_dict[feature_name]) + else: + if data_dict[feature_name].dtype in [torch.int64, torch.int32]: + data_dict[feature_name] = data_dict[feature_name].to(self.device) + else: + data_dict[feature_name] = data_dict[feature_name].float().to(self.device) + return data_dict + + def set_data_loader(self,dataset_dict=None): + data_loader_getter_class:Type[PytorchDataLoader] = GetModule.get_module_class('./Data/PytorchDataLoader', self.data_class_meta_dict['name']) + data_loader_getter = data_loader_getter_class(**self.data_class_meta_dict['args']) + if dataset_dict is not None: + pytorch_data_loader_config_dict = 
data_loader_getter.get_pytorch_data_loader_config(dataset_dict) + self.data_loader_dict = data_loader_getter.get_pytorch_data_loaders_from_config(pytorch_data_loader_config_dict) + else: + self.data_loader_dict = data_loader_getter.get_pytorch_data_loaders() + + def fit(self) -> None: + if getattr(self.h_params.train,'check_evalstep_first',False): + print("check evaluation step first whether there is no error") + with torch.no_grad(): + valid_metric = self.run_epoch(self.data_loader_dict['valid'],TrainState.VALIDATE, metric_range = "epoch") + self.log_current_state() + + for _ in range(self.current_epoch, self.total_epoch): + self.log_writer.print_and_log(f'----------------------- Start epoch : {self.current_epoch} / {self.h_params.train.epoch} -----------------------') + self.log_writer.print_and_log(f'current best epoch: {self.best_valid_epoch}') + if self.best_valid_metric is not None: + for loss_name in self.best_valid_metric: + self.log_writer.print_and_log(f'{loss_name}: {self.best_valid_metric[loss_name].avg}') + self.log_writer.print_and_log(f'-------------------------------------------------------------------------------------------------------') + + #Train + self.log_writer.print_and_log('train_start') + self.run_epoch(self.data_loader_dict['train'],TrainState.TRAIN, metric_range = "step") + + #Valid + self.log_writer.print_and_log('valid_start') + + with torch.no_grad(): + valid_metric = self.run_epoch(self.data_loader_dict['valid'],TrainState.VALIDATE, metric_range = "epoch") + self.lr_scheduler_step(call_state='epoch') #args=valid_metric) + + self.best_valid_metric = self.save_best_model(self.best_valid_metric, valid_metric) + + if self.current_epoch > self.do_log_every_epoch and self.current_epoch % self.h_params.train.save_model_every_epoch == 0: + self.save_module(self.model, name=f"step{self.global_step}_epoch{self.current_epoch}") + self.log_current_state() + + self.current_epoch += 1 + self.log_writer.log_every_epoch(model=self.model) + + if self.global_step >= self.total_step: + break + + self.log_writer.print_and_log(f'best_epoch: {self.best_valid_epoch}') + self.log_writer.print_and_log('Training complete') + + def run_epoch(self, dataloader: DataLoader, train_state:TrainState, metric_range:str = "step") -> dict: + assert metric_range in ["step","epoch"], "metric range should be 'step' or 'epoch'" + + if train_state == TrainState.TRAIN: + self.set_model_train_valid_mode(self.model, 'train') + else: + self.set_model_train_valid_mode(self.model, 'valid') + + try: dataset_size = len(dataloader) + except: dataset_size = dataloader.dataset.__len__() + + if metric_range == "epoch": + metric = dict() + + for step,data in enumerate(dataloader): + + if metric_range == "step": + metric = dict() + + if step >= len(dataloader): + break + + self.local_step = step + loss,metric = self.run_step(data,metric,train_state) + + if isinstance(loss, torch.Tensor) and torch.isnan(loss).any(): + path = os.path.join(self.log_writer.log_path["root"],f'nan_loss_data_{self.global_step}.pkl') + UtilData.pickle_save(path,data) + self.save_module(self.model, name=f"nan_loss_step{self.global_step}") + self.save_checkpoint(f"nan_loss_step{self.global_step}.pth") + raise ValueError(f'loss is nan at step {self.global_step}') + + if train_state == TrainState.TRAIN: + self.backprop(loss) + self.lr_scheduler_step(call_state='step') + + if self.local_step % self.h_params.log.log_every_local_step == 0: + self.log_metric(metrics=metric,data_size=dataset_size) + + if self.save_model_every_step is not None and 
self.global_step % self.save_model_every_step == 0 and not self.global_step == 0: + self.save_module(self.model, name=f"step{self.global_step}") + self.log_current_state() + + self.global_step += 1 + if self.global_step >= self.total_step: + return metric + + if train_state == TrainState.VALIDATE or train_state == TrainState.TEST: + self.log_metric(metrics=metric,data_size=dataset_size,train_state=train_state) + + return metric + + def log_current_state(self,train_state:TrainState = None, is_log_media:bool = True) -> None: + self.log_writer.print_and_log(f'-------------------------------------------------------------------------------------------------------') + self.log_writer.print_and_log(f'save current state') + self.log_writer.print_and_log(f'-------------------------------------------------------------------------------------------------------') + + if train_state == TrainState.TRAIN or train_state == None: + self.save_checkpoint() + self.save_checkpoint("train_checkpoint_backup.pth") + if is_log_media: + with torch.no_grad(): + self.log_media() + + self.log_writer.print_and_log(f'-------------------------------------------------------------------------------------------------------') + self.log_writer.print_and_log(f'-------------------------------------------------------------------------------------------------------') + + def backprop(self,loss): + if self.max_norm_value_for_gradient_clip is not None: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm_value_for_gradient_clip) + + if getattr(self.h_params.train,'optimizer_step_unit',1) == 1: + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + else: + loss.backward() + if (self.global_step + 1) % self.h_params.train.optimizer_step_unit == 0: + self.optimizer.step() + self.optimizer.zero_grad() + + def set_model_train_valid_mode(self, model, mode: Literal['train','valid']): + if isinstance(model, dict): + for model_name in model: + self.set_model_train_valid_mode(model[model_name], mode) + else: + if mode == 'train': + model.train() + else: + model.eval() + model.zero_grad() + + def metric_update(self, metric:Dict[str, AverageMeter], loss_name:str, loss:torch.Tensor, batch_size:int) -> dict: + if loss_name not in metric: + metric[loss_name] = AverageMeter() + metric[loss_name].update(loss.item(),batch_size) + return metric + + def save_module(self, model, model_name = '', name = 'pretrained_best_epoch'): + if isinstance(model, dict): + for model_type in model: + self.save_module(model[model_type], model_name + f'{model_type}_', name) + else: + path = os.path.join(self.log_writer.log_path["root"],f'{model_name}{name}.pth') + torch.save(model.state_dict() if not self.h_params.resource.multi_gpu else model.module.state_dict(), path) + + def load_module(self,name = 'pretrained_best_epoch'): + path = os.path.join(self.log_writer.log_path["root"],f'{name}.pth') + best_model_load = torch.load(path) + self.model.load_state_dict(best_model_load) + + def get_current_lr(self, optimizer:Union[ dict, torch.optim.Optimizer]): + if isinstance(optimizer, dict): + return self.get_current_lr(optimizer[list(optimizer.keys())[0]]) + else: + return optimizer.param_groups[0]['lr'] + + def lr_scheduler_step(self, call_state:Literal['step','epoch'], args = None): + if self.lr_scheduler is None: + return + if self.h_params.train.scheduler['interval'] == call_state: + if args is not None: + if isinstance(self.lr_scheduler, dict): + for key in self.lr_scheduler: + self.lr_scheduler[key].step(**args) + else: + 
self.lr_scheduler.step(**args) + else: + if isinstance(self.lr_scheduler, dict): + for key in self.lr_scheduler: + self.lr_scheduler[key].step() + else: + self.lr_scheduler.step() + + def save_checkpoint(self,save_name:str = 'train_checkpoint.pth'): + train_state = { + 'epoch': self.current_epoch, + 'step': self.global_step, + 'seed': self.seed, + 'model': self.get_state_dict(self.model), + 'optimizers': self.get_state_dict(self.optimizer), + 'best_metric': self.best_valid_metric, + 'best_model_epoch' : self.best_valid_epoch, + } + + if self.lr_scheduler is not None: + train_state['lr_scheduler'] = self.get_state_dict(self.lr_scheduler) + + path = os.path.join(self.log_writer.log_path["root"],save_name) + self.log_writer.print_and_log(save_name) + torch.save(train_state,path) + + def get_state_dict(self, module:Union[dict, nn.Module]) -> Union[dict, nn.Module]: + if hasattr(module, 'state_dict'): + return module.state_dict() + elif isinstance(module, dict): + state_dict = dict() + for key in module: + state_dict[key] = self.get_state_dict(module[key]) + return state_dict + else: + raise ValueError(f'Cannot get state_dict from {module}') + + def load_state_dict(self, module:Union[dict, nn.Module], state_dict:dict) -> Union[dict, nn.Module]: + if hasattr(module, 'load_state_dict'): + module.load_state_dict(state_dict) + return module + elif isinstance(module, dict): + for key in module: + module[key] = self.load_state_dict(module[key], state_dict[key]) + return module + else: + raise ValueError(f'Cannot load state_dict to {module}') + + def load_train(self, filename:str) -> None: + self.log_writer.print_and_log(f'load train from {filename}') + cpt:dict = torch.load(filename,map_location='cpu') + self.seed = cpt['seed'] + self.set_seeds(self.h_params.train.seed_strict) + self.current_epoch = cpt['epoch'] + self.global_step = cpt['step'] + + self.model_to_device(self.model, torch.device('cpu')) + self.model = self.load_state_dict(self.model, cpt['model']) + self.model_to_device(self.model) + + self.optimizer = self.load_state_dict(self.optimizer, cpt['optimizers']) + if self.lr_scheduler is not None: + self.lr_scheduler = self.load_state_dict(self.lr_scheduler, cpt['lr_scheduler']) + self.best_valid_result = cpt['best_metric'] + self.best_valid_epoch = cpt['best_model_epoch'] \ No newline at end of file diff --git a/TorchJaekwon/Util/External/RVC/AudioSlicer.py b/TorchJaekwon/Util/External/RVC/AudioSlicer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a17230301c6acb06692d2767aa1a1c922a845a4 --- /dev/null +++ b/TorchJaekwon/Util/External/RVC/AudioSlicer.py @@ -0,0 +1,261 @@ +import numpy as np + +# This function is obtained from librosa. 
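# Editor's note, not part of the patch: get_rms below frames the padded signal with a strided view and
# returns sqrt(mean(|x|^2)) per frame, i.e. the quantity librosa.feature.rms computes. A loop-based
# reference of the same idea for a 1-D float array (the helper name is the editor's, not in the repo):
import numpy as np

def _rms_reference(y: np.ndarray, frame_length: int = 2048, hop_length: int = 512) -> np.ndarray:
    y = np.pad(y, (frame_length // 2, frame_length // 2), mode="constant")
    n_frames = 1 + (len(y) - frame_length) // hop_length
    frames = [y[i * hop_length : i * hop_length + frame_length] for i in range(n_frames)]
    return np.sqrt(np.mean(np.abs(np.stack(frames)) ** 2, axis=1))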
+def get_rms( + y, + frame_length=2048, + hop_length=512, + pad_mode="constant", +): + padding = (int(frame_length // 2), int(frame_length // 2)) + y = np.pad(y, padding, mode=pad_mode) + + axis = -1 + # put our new within-frame axis at the end for now + out_strides = y.strides + tuple([y.strides[axis]]) + # Reduce the shape on the framing axis + x_shape_trimmed = list(y.shape) + x_shape_trimmed[axis] -= frame_length - 1 + out_shape = tuple(x_shape_trimmed) + tuple([frame_length]) + xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides) + if axis < 0: + target_axis = axis - 1 + else: + target_axis = axis + 1 + xw = np.moveaxis(xw, -1, target_axis) + # Downsample along the target axis + slices = [slice(None)] * xw.ndim + slices[axis] = slice(0, None, hop_length) + x = xw[tuple(slices)] + + # Calculate power + power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True) + + return np.sqrt(power) + + +class AudioSlicer: + def __init__( + self, + sr: int, + threshold: float = -42, + min_length: int = 1500, + min_interval: int = 400, + hop_size: int = 15, + max_sil_kept: int = 500, + ): + if not min_length >= min_interval >= hop_size: + raise ValueError( + "The following condition must be satisfied: min_length >= min_interval >= hop_size" + ) + if not max_sil_kept >= hop_size: + raise ValueError( + "The following condition must be satisfied: max_sil_kept >= hop_size" + ) + min_interval = sr * min_interval / 1000 + self.threshold = 10 ** (threshold / 20.0) + self.hop_size = round(sr * hop_size / 1000) + self.win_size = min(round(min_interval), 4 * self.hop_size) + self.min_length = round(sr * min_length / 1000 / self.hop_size) + self.min_interval = round(min_interval / self.hop_size) + self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) + + def _apply_slice(self, waveform, begin, end): + if len(waveform.shape) > 1: + return waveform[ + :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size) + ] + else: + return waveform[ + begin * self.hop_size : min(waveform.shape[0], end * self.hop_size) + ] + + # @timeit + def slice(self, + waveform: np.ndarray #[time] + ): + if len(waveform.shape) > 1: + samples = waveform.mean(axis=0) + else: + samples = waveform + if samples.shape[0] <= self.min_length: + return [waveform] + rms_list = get_rms( + y=samples, frame_length=self.win_size, hop_length=self.hop_size + ).squeeze(0) + sil_tags = [] + silence_start = None + clip_start = 0 + for i, rms in enumerate(rms_list): + # Keep looping while frame is silent. + if rms < self.threshold: + # Record start of silent frames. + if silence_start is None: + silence_start = i + continue + # Keep looping while frame is not silent and silence start has not been recorded. + if silence_start is None: + continue + # Clear recorded silence start if interval is not enough or clip is too short + is_leading_silence = silence_start == 0 and i > self.max_sil_kept + need_slice_middle = ( + i - silence_start >= self.min_interval + and i - clip_start >= self.min_length + ) + if not is_leading_silence and not need_slice_middle: + silence_start = None + continue + # Need slicing. Record the range of silent frames to be removed. 
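            # Editor's comments, not part of the patch, describing the three branches below:
            #  - silence no longer than max_sil_kept frames: keep it and split once at its quietest frame;
            #  - silence up to 2 * max_sil_kept frames: cut out the middle between quietest frames found
            #    near each edge, so at most ~max_sil_kept frames of silence remain on either side;
            #  - anything longer: remove everything between the quietest frame within max_sil_kept of the
            #    start of the silence and the quietest frame within max_sil_kept of its end.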
+ if i - silence_start <= self.max_sil_kept: + pos = rms_list[silence_start : i + 1].argmin() + silence_start + if silence_start == 0: + sil_tags.append((0, pos)) + else: + sil_tags.append((pos, pos)) + clip_start = pos + elif i - silence_start <= self.max_sil_kept * 2: + pos = rms_list[ + i - self.max_sil_kept : silence_start + self.max_sil_kept + 1 + ].argmin() + pos += i - self.max_sil_kept + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + clip_start = pos_r + else: + sil_tags.append((min(pos_l, pos), max(pos_r, pos))) + clip_start = max(pos_r, pos) + else: + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + else: + sil_tags.append((pos_l, pos_r)) + clip_start = pos_r + silence_start = None + # Deal with trailing silence. + total_frames = rms_list.shape[0] + if ( + silence_start is not None + and total_frames - silence_start >= self.min_interval + ): + silence_end = min(total_frames, silence_start + self.max_sil_kept) + pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start + sil_tags.append((pos, total_frames + 1)) + # Apply and return slices. + if len(sil_tags) == 0: + return [waveform] + else: + chunks = [] + if sil_tags[0][0] > 0: + chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0])) + for i in range(len(sil_tags) - 1): + chunks.append( + self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]) + ) + if sil_tags[-1][1] < total_frames: + chunks.append( + self._apply_slice(waveform, sil_tags[-1][1], total_frames) + ) + return chunks + + +def main(): + import os.path + from argparse import ArgumentParser + + import librosa + import soundfile + + parser = ArgumentParser() + parser.add_argument("audio", type=str, help="The audio to be sliced") + parser.add_argument( + "--out", type=str, help="Output directory of the sliced audio clips" + ) + parser.add_argument( + "--db_thresh", + type=float, + required=False, + default=-40, + help="The dB threshold for silence detection", + ) + parser.add_argument( + "--min_length", + type=int, + required=False, + default=5000, + help="The minimum milliseconds required for each sliced audio clip", + ) + parser.add_argument( + "--min_interval", + type=int, + required=False, + default=300, + help="The minimum milliseconds for a silence part to be sliced", + ) + parser.add_argument( + "--hop_size", + type=int, + required=False, + default=10, + help="Frame length in milliseconds", + ) + parser.add_argument( + "--max_sil_kept", + type=int, + required=False, + default=500, + help="The maximum silence length kept around the sliced clip, presented in milliseconds", + ) + args = parser.parse_args() + out = args.out + if out is None: + out = os.path.dirname(os.path.abspath(args.audio)) + audio, sr = librosa.load(args.audio, sr=None, mono=False) + slicer = Slicer( + sr=sr, + threshold=args.db_thresh, + min_length=args.min_length, + min_interval=args.min_interval, + hop_size=args.hop_size, + max_sil_kept=args.max_sil_kept, + ) + chunks = slicer.slice(audio) + if not os.path.exists(out): + os.makedirs(out) + for i, chunk in enumerate(chunks): + if len(chunk.shape) > 1: + chunk = chunk.T + soundfile.write( + 
os.path.join( + out, + f"%s_%d.wav" + % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i), + ), + chunk, + sr, + ) + + +if __name__ == "__main__": + main() diff --git a/TorchJaekwon/Util/External/RVC/UtilAudio.py b/TorchJaekwon/Util/External/RVC/UtilAudio.py new file mode 100644 index 0000000000000000000000000000000000000000..5b60fc9cbabaa6a9df374de9a38f046f98706713 --- /dev/null +++ b/TorchJaekwon/Util/External/RVC/UtilAudio.py @@ -0,0 +1,15 @@ +import numpy as np + +class UtilAudio: + @staticmethod + def norm_audio(audio: np.ndarray, #[time] + max: float = 0.9, + alpha: float = 0.75, + ) -> np.ndarray: + tmp_max = np.abs(audio).max() + assert tmp_max <= 2.5, "The maximum value of the audio is too high." + + audio = (audio / tmp_max * (max * alpha)) + ( + 1 - alpha + ) * audio + return audio.astype(np.float32) \ No newline at end of file diff --git a/TorchJaekwon/Util/External/RVC/UtilF0.py b/TorchJaekwon/Util/External/RVC/UtilF0.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c214a59fe504a5cdc49b31dd066b82bfc545da --- /dev/null +++ b/TorchJaekwon/Util/External/RVC/UtilF0.py @@ -0,0 +1,93 @@ +try: import parselmouth +except: print('[Import Error] parselmouth') +try: import pyworld +except: print('[Import Error] pyworld') + +from typing import Literal + +import numpy as np + +class UtilF0(object): + def __init__(self, + samplerate=16000, + hop_size=160): + self.fs = samplerate + self.hop = hop_size + + self.f0_mel_bin = 256 + self.f0_max = 1100.0 + self.f0_min = 50.0 + self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700) + self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700) + + def printt(self, strr): + print(strr) + self.f.write("%s\n" % strr) + self.f.flush() + + def compute_f0(self, + audio:np.ndarray,#[time] + f0_method:Literal['pm','harvest','dio'] = 'harvest' + )->np.ndarray: #[time//hop] + p_len = audio.shape[0] // self.hop + if f0_method == "pm": + time_step = 160 / 16000 * 1000 + f0_min = 50 + f0_max = 1100 + f0 = ( + parselmouth.Sound(audio, self.fs) + .to_pitch_ac( + time_step=time_step / 1000, + voicing_threshold=0.6, + pitch_floor=f0_min, + pitch_ceiling=f0_max, + ) + .selected_array["frequency"] + ) + pad_size = (p_len - len(f0) + 1) // 2 + if pad_size > 0 or p_len - len(f0) - pad_size > 0: + f0 = np.pad( + f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant" + ) + elif f0_method == "harvest": + f0, t = pyworld.harvest( + audio.astype(np.double), + fs=self.fs, + f0_ceil=self.f0_max, + f0_floor=self.f0_min, + frame_period=1000 * self.hop / self.fs, + ) + f0 = pyworld.stonemask(audio.astype(np.double), f0, t, self.fs) + elif f0_method == "dio": + f0, t = pyworld.dio( + audio.astype(np.double), + fs=self.fs, + f0_ceil=self.f0_max, + f0_floor=self.f0_min, + frame_period=1000 * self.hop / self.fs, + ) + f0 = pyworld.stonemask(audio.astype(np.double), f0, t, self.fs) + elif f0_method == "rmvpe": + if hasattr(self, "model_rmvpe") == False: + from lib.rmvpe import RMVPE + + print("loading rmvpe model") + self.model_rmvpe = RMVPE("rmvpe.pt", is_half=False, device="cpu") + f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03) + return f0 + + def get_f0_mel_index(self, f0): + f0_mel = 1127 * np.log(1 + f0 / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * ( + self.f0_mel_bin - 2 + ) / (self.f0_mel_max - self.f0_mel_min) + 1 + + # use 0 or 1 + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > self.f0_mel_bin - 1] = self.f0_mel_bin - 1 + f0_coarse = np.rint(f0_mel).astype(int) + assert f0_coarse.max() <= 255 and 
f0_coarse.min() >= 1, ( + f0_coarse.max(), + f0_coarse.min(), + ) + return f0_coarse \ No newline at end of file diff --git a/TorchJaekwon/Util/Util.py b/TorchJaekwon/Util/Util.py new file mode 100644 index 0000000000000000000000000000000000000000..3154b6627d682c8bb5a24cd8d7dc5dfca86006fa --- /dev/null +++ b/TorchJaekwon/Util/Util.py @@ -0,0 +1,65 @@ +from typing import Literal +import os, sys + +PURPLE = '\033[95m' +CYAN = '\033[96m' +DARKCYAN = '\033[36m' +BLUE = '\033[94m' +GREEN = '\033[92m' +YELLOW = '\033[93m' +RED = '\033[91m' +BOLD = '\033[1m' +UNDERLINE = '\033[4m' +END = '\033[0m' + +class Util: + @staticmethod + def print(text:str, type:Literal['info', 'success', 'warning', 'error'] = None) -> None: + template_dict:dict = { + 'info': { + 'color': BOLD + BLUE, + 'prefix': '[Info]: ' + }, + 'success': { + 'color': BOLD + GREEN, + 'prefix': '[Success]: ' + }, + 'warning': { + 'color': BOLD + YELLOW, + 'prefix': '[Warning]: ' + }, + 'error': { + 'color': BOLD + RED, + 'prefix': '[Error]: ' + } + } + color:str = template_dict.get(type, {}).get('color', '') + prefix:str = template_dict.get(type, {}).get('prefix', '') + print(f"{color + prefix + text + END}") + + @staticmethod + def set_sys_path_to_parent_dir(file:str, # __file__ + depth_to_dir_from_file:int = 1, + ) -> None: + dir : str = os.path.abspath(os.path.dirname(file)) + for _ in range(depth_to_dir_from_file): dir = os.path.dirname(dir) + sys.path[0] = os.path.abspath(dir) + + @staticmethod + def system(command:str) -> None: + result_id:int = os.system(command) + assert result_id == 0, f'[Error]: Something wrong with the command [{command}]' + + @staticmethod + def wget(link:str, save_dir:str) -> None: + os.makedirs(save_dir, exist_ok=True) + Util.system(f'wget {link} -P {save_dir}') + + @staticmethod + def unzip(file_path:str, unzip_dir:str) -> None: + os.makedirs(unzip_dir, exist_ok=True) + file_name:str = file_path.split('/')[-1] + if '.tar.gz' in file_name: + Util.system(f'''tar -xzvf {file_path} -C {unzip_dir}''') + else: + Util.system(f'''unzip {file_path} -d {unzip_dir}''') \ No newline at end of file diff --git a/TorchJaekwon/Util/UtilAudio.py b/TorchJaekwon/Util/UtilAudio.py new file mode 100644 index 0000000000000000000000000000000000000000..42b806c35e6e00dbbb94f52b7d07dde28b3335ef --- /dev/null +++ b/TorchJaekwon/Util/UtilAudio.py @@ -0,0 +1,263 @@ +from typing import Optional, Literal, Union, Final, List +from numpy import ndarray + +import os +from tqdm import tqdm +import numpy as np +import soundfile as sf +import librosa +from scipy.signal import resample_poly + +try: import torch +except: print('import error: torch') +try: from torch import Tensor +except: print('') +try: import torchaudio +except: print('import error: torch') +try: from pydub import AudioSegment +except: print('import error: pydub') + +from TorchJaekwon.Util.UtilData import UtilData + +DATA_TYPE_MIN_MAX_DICT:Final[dict] = {'float32':(-1,1), 'float64':(-1,1), 'int16':(-2**15, 2**15-1), 'int32':(-2**31,2**31-1)} + +class UtilAudio: + @staticmethod + def change_dtype(audio:ndarray, + current_dtype:Literal['float32', 'float64', 'int16', 'int32'], + target_dtype:Literal['float32', 'float64', 'int16', 'int32'] + ) -> ndarray: + audio = np.clip(audio, a_min = DATA_TYPE_MIN_MAX_DICT[current_dtype][0], a_max = DATA_TYPE_MIN_MAX_DICT[current_dtype][1]) + audio = audio / DATA_TYPE_MIN_MAX_DICT[current_dtype][1] + audio = (audio * DATA_TYPE_MIN_MAX_DICT[target_dtype][1]) + audio = audio.astype(getattr(np,target_dtype)) + return audio + + @staticmethod + 
def resample_audio(audio:Union[ndarray, Tensor], #[shape=(channel, num_samples) or (num_samples)] + origin_sr:int, + target_sr:int, + resample_module:Literal['librosa', 'resample_poly', 'torchaudio'] = 'librosa', + resample_type:str = "kaiser_fast", + audio_path:Optional[str] = None): + if(origin_sr == target_sr): return audio + #print(f"resample audio {origin_sr} to {target_sr}") + if resample_module == 'librosa': + return librosa.resample(audio, orig_sr=origin_sr, target_sr=target_sr, res_type=resample_type) + elif resample_module == 'resample_poly': + return resample_poly(x = audio, up = target_sr, down = origin_sr) + elif resample_module == 'torchaudio': + #transforms.Resample precomputes and caches the kernel used for resampling, while functional.resample computes it on the fly + #so using torchaudio.transforms.Resample will result in a speedup when resampling multiple waveforms using the same parameters + return torchaudio.transforms.Resample(orig_freq = origin_sr, new_freq = target_sr)(audio) + + @staticmethod + def read(audio_path:str, + sample_rate:Optional[int] = None, + mono:Optional[bool] = None, + start_idx:int = 0, + end_idx:Optional[int] = None, + module_name:Literal['soundfile','librosa', 'torchaudio'] = 'torchaudio', + return_type:Union[ndarray, Tensor] = ndarray + ) -> Union[ndarray, Tensor]: #[shape=(channel, num_samples) or (num_samples)] + + if module_name == "soundfile": + audio_data, original_samplerate = sf.read(audio_path) + if len(audio_data.shape) > 1 : audio_data = audio_data.T + + if sample_rate is not None and sample_rate != original_samplerate: + #print(f"resample audio {original_samplerate} to {sample_rate}") + audio_data = UtilAudio.resample_audio(audio_data,original_samplerate,sample_rate) + + elif module_name == "librosa": + print(f"read audio sr: {sample_rate}") + audio_data, original_samplerate = librosa.load( audio_path, sr=sample_rate, mono=mono) + + elif module_name == 'torchaudio': + if end_idx is not None: assert end_idx > start_idx, f'[Error] end_idx must be larger than start_idx' + #[channel, time], int + audio_data, original_samplerate = torchaudio.load(audio_path, + frame_offset = start_idx, + num_frames = -1 if end_idx is None else end_idx - start_idx) + if sample_rate is not None and sample_rate != original_samplerate: + audio_data = UtilAudio.resample_audio(audio = audio_data, origin_sr=original_samplerate, target_sr = sample_rate, resample_module='torchaudio', audio_path = audio_path) + + if mono is not None: + if mono and len(audio_data.shape) == 2 and audio_data.shape[0] == 2: + audio_data = torch.mean(audio_data,axis=0) if isinstance(audio_data, torch.Tensor) else np.mean(audio_data,axis=0) + elif not mono and (len(audio_data.shape) == 1 or audio_data.shape[0] == 1): + stereo_audio = torch.zeros((2,len(audio_data.squeeze()))) + stereo_audio[0,...] = audio_data.squeeze() + stereo_audio[1,...] 
= audio_data.squeeze() + audio_data = stereo_audio + + assert ((len(audio_data.shape)==1) or ((len(audio_data.shape)==2) and audio_data.shape[0] in [1,2])),f'[read audio shape problem] path: {audio_path} shape: {audio_data.shape}' + + return audio_data, original_samplerate if sample_rate is None else sample_rate + + @staticmethod + def write(audio_path:str, + audio:Union[ndarray, Tensor], + sample_rate:int, + ) -> None: + os.makedirs(os.path.dirname(audio_path), exist_ok=True) + if isinstance(audio, Tensor): + audio = audio.squeeze().cpu().detach().numpy() + assert len(audio.shape) <= 2, f'[Error] shape of {audio_path}: {audio.shape}' + if len(audio.shape) == 2 and audio.shape[0] < audio.shape[1]: audio = audio.T + sf.write(file = audio_path, data = audio, samplerate = sample_rate) + + @staticmethod + def stereo_to_mono(audio_data:Union[ndarray, Tensor]) -> Union[ndarray, Tensor]: + audio_data = np.mean(audio_data,axis=1) + return audio_data + + @staticmethod + def mono_to_stereo(audio_data:Union[ndarray, Tensor]) -> Union[ndarray, Tensor]: + stereo_audio = np.zeros((2,len(audio_data))) + stereo_audio[0,...] = audio_data + stereo_audio[1,...] = audio_data + audio_data = stereo_audio + return audio_data + + @staticmethod + def normalize_volume(audio_input:ndarray,sr:int, target_dBFS = -30): + audio = UtilAudio.change_dtype(audio=audio_input,current_dtype='float64',target_dtype='int32')#UtilAudio.float64_to_int32(audio_input) + audio_segment = AudioSegment(audio.tobytes(), frame_rate=sr, sample_width=audio.dtype.itemsize, channels=1) + change_in_dBFS = target_dBFS - audio_segment.dBFS + normalizedsound = audio_segment.apply_gain(change_in_dBFS) + return UtilAudio.change_dtype(audio=np.array(normalizedsound.get_array_of_samples()),current_dtype='int32',target_dtype='float64') #UtilAudio.int32_to_float64(np.array(normalizedsound.get_array_of_samples())) + + @staticmethod + def normalize_by_fro_norm(audio_input:Tensor #[batch, channel, time] + ) -> Tensor: + original_shape:tuple = audio_input.shape + audio = audio_input.reshape(original_shape[0], -1) + audio = audio/torch.norm(audio, p="fro", dim=1, keepdim=True) + audio = audio.reshape(*original_shape) + return audio + + @staticmethod + def energy_unify(estimated, original, eps = 1e-12): + target = UtilAudio.pow_norm(estimated, original) * original + target /= UtilAudio.pow_p_norm(original) + eps + return estimated, target + + @staticmethod + def pow_norm(s1, s2): + return torch.sum(s1 * s2) + + @staticmethod + def pow_p_norm(signal): + return torch.pow(torch.norm(signal, p=2), 2) + + @staticmethod + def get_segment_index_list(audio:ndarray, #[time] + sample_rate:int, + segment_sample_length:int, + hop_seconds:float = 0.1 + ) -> list: + begin_sample:int = 0 + hop_samples = int(hop_seconds * sample_rate) + segment_index_list = list() + while (begin_sample == 0) or (begin_sample + segment_sample_length < len(audio)): + segment_index_list.append({'begin':begin_sample, 'end':begin_sample + segment_sample_length}) + begin_sample += hop_samples + return segment_index_list + + @staticmethod + def audio_to_batch(audio:Tensor, #[Length] + segment_length:int, + overlap_length:int = 48000 #recommend: int(sr * 0.5) + ): + assert len(audio.shape) == 1, f'[Error] audio shape must be 1, but {audio.shape}' + start_idx:int = 0 + audio_list = list() + while start_idx < len(audio): + audio_segment = audio[start_idx:start_idx+segment_length] + audio_segment = UtilData.fix_length(audio_segment, segment_length) + audio_list.append(audio_segment) + start_idx += 
segment_length - overlap_length + return torch.stack(audio_list) + + @staticmethod + def merge_batch_w_cross_fade(batch_audio:Union[List[ndarray],ndarray,Tensor], + segment_length:int, + overlap_length:int = 48000 #recommend: int(sr * 0.5) + ) -> ndarray: + ''' + reference from https://github.com/nkandpa2/music_enhancement/blob/master/scripts/generate_from_wav.py + ''' + if isinstance(batch_audio, ndarray) and len(batch_audio.shape) == 1: + batch_audio = [batch_audio] + output_audio_length:int = len(batch_audio) * segment_length - (len(batch_audio) - 1) * overlap_length + output_audio:Union[ndarray,Tensor] = torch.zeros(output_audio_length) if isinstance(batch_audio, torch.Tensor) else np.zeros(output_audio_length) + hop_length:int = segment_length - overlap_length + + cross_fade_in:ndarray = np.linspace(0, 1, overlap_length) + cross_fade_out:ndarray = 1 - cross_fade_in + if isinstance(batch_audio, torch.Tensor): + cross_fade_in = torch.tensor(cross_fade_in, device = batch_audio.device) + cross_fade_out = torch.tensor(cross_fade_out, device = batch_audio.device) + + for i in range(0,len(batch_audio)): + start_idx:int = i * hop_length + if i != 0: + batch_audio[i][:overlap_length] *= cross_fade_in + if i != len(batch_audio) - 1: + batch_audio[i][-overlap_length:] *= cross_fade_out + output_audio[start_idx:start_idx+segment_length] += batch_audio[i] + return output_audio + + @staticmethod + def analyze_audio_dataset(data_dir:str, + result_save_dir:str, + sanity_check_sr:Union[int,List[int]] = None, + save_each_meta:bool = False + ) -> None: + total_meta_dict:dict = { + 'total_duration_second': 0, + 'total_duration_minutes': 0, + 'total_duration_hours': 0, + + 'longest_sample_meta': { + 'file_name': '', + 'duration_second':0 + }, + + 'error_file_list': list() + } + if sanity_check_sr is not None: total_meta_dict['sample_rate'] = sanity_check_sr + + audio_meta_data_list = UtilData.walk(dir_name=data_dir, ext=['.wav', '.mp3', '.flac']) + for meta_data in tqdm(audio_meta_data_list): + try: + audio, sr = UtilAudio.read(meta_data['file_path'], mono=True) + except: + print(f'Error: {meta_data["file_path"]}') + total_meta_dict['error_file_list'].append(meta_data['file_path']) + continue + if sanity_check_sr is not None: + if isinstance(sanity_check_sr, int): assert sr == sanity_check_sr, f'''{meta_data['file_path']}'s sample rate is {sr}''' + if isinstance(sanity_check_sr, list): assert sr in sanity_check_sr, f'''{meta_data['file_path']}'s sample rate is {sr}''' + + meta_data_of_this_file = { + 'file_name': meta_data['file_name'], + 'file_path': os.path.abspath(meta_data['file_path']), + 'sample_length': audio.shape[-1], + 'sample_rate': sr, + } + meta_data_of_this_file['duration_second'] = meta_data_of_this_file['sample_length'] / meta_data_of_this_file['sample_rate'] + + save_dir:str = meta_data['dir_path'].replace(data_dir, result_save_dir) + if save_each_meta: UtilData.pickle_save(f'''{save_dir}/{meta_data['file_name']}.pkl''', meta_data_of_this_file) + + total_meta_dict['total_duration_second'] += meta_data_of_this_file['duration_second'] + if total_meta_dict['longest_sample_meta']['duration_second'] < meta_data_of_this_file['duration_second']: + total_meta_dict['longest_sample_meta'] = meta_data_of_this_file + + total_meta_dict['total_duration_minutes'] = total_meta_dict['total_duration_second'] / 60 + total_meta_dict['total_duration_hours'] = total_meta_dict['total_duration_second'] / 3600 + UtilData.yaml_save(save_path = f'{result_save_dir}/meta.yaml', data = total_meta_dict) + + diff --git 
a/TorchJaekwon/Util/UtilAudioMelSpec.py b/TorchJaekwon/Util/UtilAudioMelSpec.py new file mode 100644 index 0000000000000000000000000000000000000000..f2fba4a8f2cfa2818d351ab9963d19bd2a1c5e66 --- /dev/null +++ b/TorchJaekwon/Util/UtilAudioMelSpec.py @@ -0,0 +1,111 @@ +#type +from typing import Union +from numpy import ndarray +from torch import Tensor +#package +import os +import torch +import numpy as np +import librosa.display +from librosa.filters import mel as librosa_mel_fn +try: + import matplotlib.pyplot as plt +except: + print('matplotlib is uninstalled') +#torchjaekwon +from TorchJaekwon.Util.UtilAudioSTFT import UtilAudioSTFT +from TorchJaekwon.Util.UtilTorch import UtilTorch + +class UtilAudioMelSpec(UtilAudioSTFT): + def __init__(self, + nfft: int, + hop_size: int, + sample_rate:int, + mel_size:int, + frequency_min:float, + frequency_max:float): + super().__init__(nfft, hop_size) + + self.sample_rate:int = sample_rate + self.mel_size:int = mel_size + self.frequency_min:float = frequency_min + self.frequency_max:float = frequency_max if frequency_max is not None else sample_rate//2 + + #[self.mel_size, self.nfft//2 + 1] + self.mel_basis_np:ndarray = librosa_mel_fn(sr = self.sample_rate, + n_fft = self.nfft, + n_mels = self.mel_size, + fmin = self.frequency_min, + fmax = self.frequency_max) + self.mel_basis_tensor:Tensor = torch.from_numpy(self.mel_basis_np).float() + self.mel_frequncies = librosa.mel_frequencies(n_mels = self.mel_size, + fmin = self.frequency_min, + fmax = self.frequency_max) + + @staticmethod + def get_default_mel_spec_config(sample_rate:int = 16000) -> dict: + nfft:int = 1024 if sample_rate <= 24000 else 2048 + mel_size:int = 80 if sample_rate <= 24000 else 128 + return {'nfft': nfft, 'hop_size': nfft//4, 'sample_rate': sample_rate, 'mel_size': mel_size, 'frequency_max': sample_rate//2, 'frequency_min': 0} + + def spec_to_mel_spec(self,stft_mag): + if type(stft_mag) == np.ndarray: + return np.matmul(self.mel_basis_np, stft_mag) + elif type(stft_mag) == torch.Tensor: + self.mel_basis_tensor = self.mel_basis_tensor.to(stft_mag.device) + return torch.matmul(self.mel_basis_tensor, stft_mag) + else: + print("spec_to_mel_spec type error") + exit() + + def dynamic_range_compression(self, x, C=1, clip_val=1e-5): + if type(x) == np.ndarray: + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + elif type(x) == torch.Tensor: + return torch.log(torch.clamp(x, min=clip_val) * C) + else: + print("dynamic_range_compression type error") + exit() + + def get_hifigan_mel_spec(self, + audio:Union[ndarray,Tensor], #[Batch,Time] + return_type:str=['ndarray','Tensor'][1] + ) -> Union[ndarray,Tensor]: + if isinstance(audio,ndarray): audio = torch.FloatTensor(audio) + while len(audio.shape) < 2: audio = audio.unsqueeze(0) + + if torch.min(audio) < -1.: + print('min value is ', torch.min(audio)) + if torch.max(audio) > 1.: + print('max value is ', torch.max(audio)) + + spectrogram = self.stft_torch(audio)["mag"] + mel_spec = self.spec_to_mel_spec(spectrogram) + log_scale_mel = self.dynamic_range_compression(mel_spec) + + if return_type == 'ndarray': + return log_scale_mel.cpu().detach().numpy() + else: + return log_scale_mel + + def mel_spec_plot(self, + save_path:str, #'*.png' + mel_spec:ndarray, #[mel_size, time] + fig_size:tuple=(8,4), + dpi:int = 500) -> None: + assert(os.path.splitext(save_path)[1] == ".png") , "file extension should be '.png'" + if isinstance(mel_spec, Tensor): + mel_spec = UtilTorch.to_np(mel_spec) + plt.figure(figsize=fig_size) + plt.imshow(mel_spec, 
origin='lower', aspect='auto', cmap='viridis') + plt.savefig(save_path,dpi=dpi) + plt.close() + + def f0_to_melbin(self, + f0:Tensor # 1d f0 tensor + ) -> Tensor: + mel_frequencies = torch.FloatTensor(self.mel_frequncies).repeat(f0.shape[0]).reshape(f0.shape[0],-1).to(f0.device) + mel_frequencies[((mel_frequencies - f0.unsqueeze(-1)) < 0)] = np.inf + all_inf_value = torch.all(torch.isinf(mel_frequencies), dim = 1) + mel_frequencies[all_inf_value,-1] = 0 + return torch.argmin(mel_frequencies, dim=1) diff --git a/TorchJaekwon/Util/UtilAudioPlus.py b/TorchJaekwon/Util/UtilAudioPlus.py new file mode 100644 index 0000000000000000000000000000000000000000..2c3f174c49d15dc2ade75a76b647aadad1041eca --- /dev/null +++ b/TorchJaekwon/Util/UtilAudioPlus.py @@ -0,0 +1,62 @@ +from typing import Dict +from numpy import ndarray + +import resampy +import torch +import torchcrepe +import numpy as np + +from TorchJAEKWON.DataProcess.Util.UtilAudio import UtilAudio + +class UtilAudioPlus(UtilAudio): + def get_pitch_crepe(self, + wav:ndarray, #mono 1d array + sample_rate:float, + hop_size:int, + spec_time_bin_length:int, + f0_min:float = 50.0, + f0_max:float = 1100.0, + threshold:float=0.05, + device = torch.device("cuda")) -> Dict[str,ndarray]: + + wav16k = resampy.resample(wav, sample_rate, 16000) + wav16k_torch = torch.FloatTensor(wav16k).unsqueeze(0).to(device) + + f0, pd = torchcrepe.predict(wav16k_torch, 16000, 80, f0_min, f0_max, pad=True, model='full', batch_size=1024, device=device, return_periodicity=True) + + pd = torchcrepe.filter.median(pd, 3) + pd = torchcrepe.threshold.Silence(-60.)(pd, wav16k_torch, 16000, 80) + f0 = torchcrepe.threshold.At(threshold)(f0, pd) + f0 = torchcrepe.filter.mean(f0, 3) + + f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0) + + nzindex = torch.nonzero(f0[0]).squeeze() + f0 = torch.index_select(f0[0], dim=0, index=nzindex).cpu().numpy() + time_org = 0.005 * nzindex.cpu().numpy() + time_frame = np.arange(spec_time_bin_length) * hop_size / sample_rate + if f0.shape[0] == 0: + f0 = torch.FloatTensor(time_frame.shape[0]).fill_(0) + print('f0 all zero!') + else: + f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) + pitch_coarse = self.f0_to_coarse(f0) + return {'f0':f0, 'pitch':pitch_coarse} + + def f0_to_coarse(self, + f0:ndarray, + f0_bin:int = 256, + f0_min:float = 50.0, + f0_max:float = 1100.0) -> ndarray: + + is_torch = isinstance(f0, torch.Tensor) + f0_mel_min = 1127 * np.log(1 + f0_min / 700) + f0_mel_max = 1127 * np.log(1 + f0_max / 700) + f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 + + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 + f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(int) + assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min()) + return f0_coarse \ No newline at end of file diff --git a/TorchJaekwon/Util/UtilAudioSTFT.py b/TorchJaekwon/Util/UtilAudioSTFT.py new file mode 100644 index 0000000000000000000000000000000000000000..a860f97b34cbc60f50757f31392e642f695c9e36 --- /dev/null +++ b/TorchJaekwon/Util/UtilAudioSTFT.py @@ -0,0 +1,121 @@ +#type +from typing import Union,Dict +from numpy import ndarray +from torch import Tensor + +import matplotlib.pyplot as plt +import numpy as np +import torch +import librosa +import librosa.display + +from TorchJaekwon.Util.UtilAudio import UtilAudio + +class 
UtilAudioSTFT(UtilAudio): + def __init__(self,nfft:int, hop_size:int): + super().__init__() + self.nfft = nfft + self.hop_size = hop_size + self.hann_window = torch.hann_window(self.nfft) + + def get_mag_phase_stft_np(self,audio): + stft = librosa.stft(audio,n_fft=self.nfft, hop_length=self.hop_size) + mag = abs(stft) + phase = np.exp(1.j * np.angle(stft)) + return {"mag":mag,"phase":phase} + + def get_mag_phase_stft_np_mono(self,audio): + if audio.shape[0] == 2: + return self.get_mag_phase_stft_np(np.mean(audio,axis=0)) + else: + return self.get_mag_phase_stft_np(audio) + + + def stft_torch(self, + audio:Union[ndarray,Tensor] # [time] or [batch, time] or [batch, channel, time] + ) -> Dict[str,Tensor]: + + audio_torch:Tensor = torch.from_numpy(audio) if type(audio) == np.ndarray else audio + + assert(len(audio_torch.shape) <= 3), f'Error: stft_torch() audio torch shape is {audio_torch.shape}' + + if (len(audio_torch.shape) == 1): audio_torch = audio_torch.unsqueeze(0) + + shape_is_three = True if len(audio_torch.shape) == 3 else False + if shape_is_three: + batch_size, channels_num, segment_samples = audio_torch.shape + audio_torch = audio_torch.reshape(batch_size * channels_num, segment_samples) + + spec_dict:Dict[str,Tensor] = dict() + + audio_torch = torch.nn.functional.pad(audio_torch.unsqueeze(1), (int((self.nfft-self.hop_size)/2), int((self.nfft-self.hop_size)/2)), mode='reflect').squeeze(1) + spec_dict['stft'] = torch.stft(audio_torch, + self.nfft, + hop_length=self.hop_size, + window=self.hann_window.to(audio_torch.device), + center=False, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True) + ''' + spec_dict['stft'] = torch.stft(audio_torch, + n_fft=self.nfft, + hop_length=self.hop_size, + window=self.hann_window.to(audio_torch.device), + return_complex=True) + ''' + spec_dict['mag'] = spec_dict['stft'].abs() + spec_dict['angle'] = spec_dict['stft'].angle() + + if shape_is_three: + _, time_steps, freq_bins = spec_dict['stft'].shape + for feature_name in spec_dict: + spec_dict[feature_name] = spec_dict[feature_name].reshape(batch_size, channels_num, time_steps, freq_bins) + + return spec_dict + + def istft_torch_from_mag_and_angle(self, + mag:Tensor, + angel:Tensor): + stft_complex:Tensor = torch.polar(abs = mag, angle = angel) + return torch.istft(stft_complex, self.nfft, hop_length=self.hop_size,window=self.hann_window.to(stft_complex.device), + center=True, onesided=True) + + def get_pred_accom_by_subtract_pred_vocal_audio(self,pred_vocal,mix_audio): + pred_vocal_mag = self.get_mag_phase_stft_np_mono(pred_vocal)["mag"] + mix_stft = self.get_mag_phase_stft_np_mono(mix_audio) + mix_mag = mix_stft["mag"] + mix_phase = mix_stft["phase"] + pred_accom_mag = mix_mag - pred_vocal_mag + pred_accom_mag[pred_accom_mag < 0] = 0 + pred_accom = librosa.istft(pred_accom_mag*mix_phase,hop_length=self.hop_size,length=mix_audio.shape[-1]) + return pred_accom + + def stft_plot_from_audio_path(self,audio_path:str,save_path:str = None, dpi:int = 500) -> None: + audio, sr = librosa.load(audio_path) + stft_audio:ndarray = librosa.stft(audio) + spectrogram_db_scale:ndarray = librosa.amplitude_to_db(np.abs(stft_audio), ref=np.max) + plt.figure(dpi=dpi) + librosa.display.specshow(spectrogram_db_scale) + plt.colorbar() + if save_path is not None: + plt.savefig(save_path,dpi=dpi) + + @staticmethod + def spec_to_figure(spec, + vmin:float = -6.0, + vmax:float = 1.5, + fig_size:tuple = (12,6), + dpi = 400, + transposed=False, + save_path=None): + if isinstance(spec, torch.Tensor): 
+ spec = spec.squeeze().cpu().numpy() + spec = spec.squeeze() + fig = plt.figure(figsize=fig_size, dpi = dpi) + plt.pcolor(spec.T if transposed else spec, vmin=vmin, vmax=vmax) + if save_path is not None: + plt.savefig(save_path,dpi=dpi) + plt.close() + return fig \ No newline at end of file diff --git a/TorchJaekwon/Util/UtilAudioVocalPresence.py b/TorchJaekwon/Util/UtilAudioVocalPresence.py new file mode 100644 index 0000000000000000000000000000000000000000..fe63157693e2515cf29b4186a7e5fd7ad1d6190b --- /dev/null +++ b/TorchJaekwon/Util/UtilAudioVocalPresence.py @@ -0,0 +1,33 @@ +from torch import Tensor + +import torch + +from DataProcess.Util.UtilAudio import UtilAudio + +class UtilAudioVocalPresence(UtilAudio): + def __init__(self) -> None: + super().__init__() + self.energy_threshold:float = 4 #1 + + def get_vocal_presence_from_raw_vocal_spec( + self, + raw_vocal_spec_mag:Tensor # (batch, audio_channel, frequency, time) + ) -> Tensor: + + freq_dim = len(raw_vocal_spec_mag.shape) - 2 + + vocal_energy:Tensor = raw_vocal_spec_mag * raw_vocal_spec_mag + vocal_energy = torch.sum(vocal_energy, dim=freq_dim, keepdim=True) + vocal_energy[vocal_energy=self.energy_threshold] = 1 + + return vocal_energy + + def apply_vocal_presence_to_spec( + self, + raw_vocal_spec_mag:Tensor, # (batch, audio_channel, frequency, time) + voice_presence:Tensor + )->Tensor: + assert len(raw_vocal_spec_mag.shape) == 4, f"The shape of raw_vocal_spec_mag shpuld be (batch, audio_channel, frequency, time)" + vocal_presence_mask:Tensor = voice_presence.repeat(1,1,raw_vocal_spec_mag.shape[2],1) + return raw_vocal_spec_mag * vocal_presence_mask \ No newline at end of file diff --git a/TorchJaekwon/Util/UtilData.py b/TorchJaekwon/Util/UtilData.py new file mode 100644 index 0000000000000000000000000000000000000000..7103a20b681be231538fc1ad173b6c1531a564bc --- /dev/null +++ b/TorchJaekwon/Util/UtilData.py @@ -0,0 +1,215 @@ +from typing import Union,Dict,List +from numpy import ndarray +from torch import Tensor + +import os +from tqdm import tqdm +import random +import copy +import numpy as np +import torch +import torch.nn.functional as F +import pickle, yaml, csv, json +from pathlib import Path +from inspect import isfunction + +class UtilData: + + @staticmethod + def get_file_name(file_path:str, with_ext:bool = False) -> str: + if file_path is None: + print("warning: path is None") + return "" + path_pathlib = Path(file_path) + if with_ext: + return path_pathlib.name + else: + return path_pathlib.stem + + @staticmethod + def pickle_save(save_path:str, data:Union[ndarray,Tensor]) -> None: + if not (os.path.splitext(save_path)[1] == ".pkl"): + print("file extension should be '.pkl'") + save_path = f'{save_path}.pkl' + + os.makedirs(os.path.dirname(save_path),exist_ok=True) + + with open(save_path,'wb') as file_writer: + pickle.dump(data,file_writer) + + @staticmethod + def pickle_load(data_path:str) -> Union[ndarray,Tensor]: + with open(data_path, 'rb') as pickle_file: + data:Union[ndarray,Tensor] = pickle.load(pickle_file) + return data + + @staticmethod + def yaml_save(save_path:str, data:Union[dict,list], sort_keys:bool = False) -> None: + assert(os.path.splitext(save_path)[1] == ".yaml") , "file extension should be '.yaml'" + + with open(save_path, 'w') as file: + yaml.dump(data, file, sort_keys = sort_keys, allow_unicode=True) + + @staticmethod + def yaml_load(data_path:str) -> dict: + yaml_file = open(data_path, 'r') + return yaml.safe_load(yaml_file) + + @staticmethod + def csv_load(data_path:str) -> list: + 
row_result_list = list() + with open(data_path, newline='') as csvfile: + spamreader = csv.reader(csvfile)#, delimiter=' ', quotechar='|') + for row in spamreader: + row_result_list.append(row) + return row_result_list + + @staticmethod + def txt_load(data_path:str) -> list: + with open(data_path, 'r') as txtfile: + return txtfile.readlines() + + @staticmethod + def txt_save(save_path:str, string_list:List[str], new_file:bool = True) -> list: + os.makedirs(os.path.dirname(save_path), exist_ok=True) + with open(save_path, 'w' if new_file else 'a') as file: + for line in string_list: + file.write(f'{line}\n') + + @staticmethod + def csv_save(file_path:str, + data_dict_list:List[Dict[str,object]], #[ {key:object}, ... ] + order_of_key:list = None # [key1, key2, ...] + ) -> list: + import pandas as pd + if order_of_key is None: + order_of_key = list(data_dict_list[0].keys()) + csv_save_dict:dict = {key:list() for key in order_of_key} + for data_dict in data_dict_list: + for key in csv_save_dict: + csv_save_dict[key].append(data_dict[key]) + pd.DataFrame(csv_save_dict).to_csv(file_path) + + @staticmethod + def json_load(file_path:str) -> dict: + with open(file_path) as f: data = f.read() + return json.loads(data) + + @staticmethod + def save_data_segment(save_dir:str,data:ndarray,segment_len:int,segment_axis:int=-1,remainder:str = ['discard','pad','maintain'][1],ext:str = ['pkl'][0]): + os.makedirs(save_dir,exist_ok=True) + data_total = copy.deepcopy(data) + total_length_of_data:int = data_total.shape[segment_axis] + + if total_length_of_data % segment_len != 0 and remainder in ['discard','pad']: + if remainder == 'discard': + data_total = data_total.take(indices=range(0, total_length_of_data - (total_length_of_data % segment_len)), axis=segment_axis) + else: + assert(segment_axis==-1 and (len(data_total.shape) in [1,2])),'Error[UtilData.save_data_segment] not implemented yet' + pad_length:int = segment_len - (total_length_of_data % segment_len) + if len(data_total.shape) == 1: + data_total = np.pad(data_total, (0, pad_length), 'constant') + elif len(data_total.shape) == 2: + data_total = np.pad(data_total, ((0,0),(0,pad_length)), 'constant') + total_length_of_data:int = data_total.shape[segment_axis] + + for start_idx in range(0,total_length_of_data,segment_len): + end_idx:int = start_idx + segment_len + if remainder == 'maintain' and end_idx >= total_length_of_data: end_idx = total_length_of_data - 1 + + data_segment = data_total.take(indices=range(start_idx, end_idx), axis=segment_axis) + + assert(data_segment.shape[segment_axis] == segment_len),'Error[UtilData.save_data_segment] segment length error!!' 
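+ # note: each full-length segment is saved below under its starting sample index, e.g. 0.pkl, 48000.pkl, ...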
+ if ext == 'pkl': + UtilData.pickle_save(f'{save_dir}/{start_idx}.{ext}',data_segment) + + @staticmethod + def fit_shape_length(feature:Union[Tensor,ndarray],shape_length:int, dim:int = 0) -> Tensor: + if shape_length == len(feature.shape): + return feature + if type(feature) != torch.Tensor: + feature = torch.from_numpy(feature) + + feature = torch.squeeze(feature) + + for _ in range(shape_length - len(feature.shape)): + feature = torch.unsqueeze(feature, dim=dim) + + return feature + + @staticmethod + def sort_dict_list( dict_list: List[dict], key:str, reverse:bool = False): + return sorted(dict_list, key = lambda dictionary: dictionary[key], reverse=reverse) + + @staticmethod + def random_segment(data:ndarray, data_length:int) -> ndarray: + max_data_start = len(data) - data_length + data_start = random.randint(0, max_data_start) + return data[data_start:data_start+data_length] + + @staticmethod + def default(val, d): + if val is not None: + return val + return d() if isfunction(d) else d + + @staticmethod + def fix_length(data:Union[ndarray,Tensor], + length:int, + dim:int = -1 + ) -> Tensor: + assert len(data.shape) in [1,2,3], "Error[UtilData.fix_length] only support when data.shape is 1, 2 or 3" + if data.shape[dim] < length: + if isinstance(data,Tensor): + return F.pad(data, (0,length - data.shape[dim]), "constant", 0) + else: + return F.pad(torch.from_numpy(data), (0,length - data.shape[dim]), "constant", 0).numpy() + elif data.shape[dim] == length: + return data + else: + assert dim == -1, "Error[UtilData.fix_length] slicing when dim is not -1 not implemented yet" + return data[..., :length] + + @staticmethod + def listdir(dir_name:str, ext:Union[str,list] = ['.wav', '.mp3', '.flac']) -> list: + if ext is None: + return os.listdir(dir_name) + elif isinstance(ext,list): + return [{'file_name': file_name, 'file_path':f'{dir_name}/{file_name}'} for file_name in os.listdir(dir_name) if os.path.splitext(file_name)[1] in ext] + else: + return [{'file_name': file_name, 'file_path':f'{dir_name}/{file_name}'} for file_name in os.listdir(dir_name) if os.path.splitext(file_name)[1] == ext] + + @staticmethod + def walk(dir_name:str, ext:list = ['.wav', '.mp3', '.flac']) -> list: + file_meta_list:list = list() + for root, _, files in os.walk(dir_name): + for filename in tqdm(files, desc=f'walk {root}'): + if os.path.splitext(filename)[-1] in ext: + file_meta_list.append({ + 'file_name': UtilData.get_file_name( file_path = filename ), + 'file_path': f'{root}/{filename}', + 'dir_name': root.replace(dir_name,'').replace('/',''), + 'dir_path': root, + }) + return file_meta_list + + @staticmethod + def get_dir_name_list(root_dir:str) -> list: + return [dir_name for dir_name in os.listdir(root_dir) if os.path.isdir(f'{root_dir}/{dir_name}')] + + @staticmethod + def pretty_num(number:float) -> str: + if number < 1000: + return str(number) + elif number < 1000000: + return f'{round(number/1000,5)}K' + elif number < 1000000000: + return f'{round(number/1000000,5)}M' + else: + return f'{round(number/1000000000,5)}B' + + @staticmethod + def extract_num_from_str(string:str) -> float: + return float(''.join([c for c in string if c.isdigit() or c == '.'])) + + diff --git a/TorchJaekwon/Util/UtilDiffusion.py b/TorchJaekwon/Util/UtilDiffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..06e2ebeaf505dc2dd5c9311d7c947a50fde93923 --- /dev/null +++ b/TorchJaekwon/Util/UtilDiffusion.py @@ -0,0 +1,3 @@ +class UtilDiffusion: + def __init__(self) -> None: + pass \ No newline at end of file 
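The utility classes above are easiest to follow with a small end-to-end sketch. The snippet below is not part of the diff; it is a hypothetical example of combining UtilData.walk, UtilAudio.read and UtilData.fix_length to load a folder of clips into one fixed-length batch. The directory path, sample rate and clip length are placeholder assumptions.

# Hypothetical usage sketch for the utilities above (not part of the diff).
import torch
from TorchJaekwon.Util.UtilAudio import UtilAudio
from TorchJaekwon.Util.UtilData import UtilData

SAMPLE_RATE = 48_000            # assumed target rate
TARGET_LEN = 10 * SAMPLE_RATE   # pad/trim every clip to 10 s

clips = []
for meta in UtilData.walk(dir_name='./data', ext=['.wav', '.flac']):
    # read() resamples to SAMPLE_RATE and averages stereo channels when mono=True
    audio, sr = UtilAudio.read(meta['file_path'], sample_rate=SAMPLE_RATE, mono=True)
    audio = audio.squeeze()                       # ensure a 1-D tensor
    clips.append(UtilData.fix_length(audio, TARGET_LEN))

batch = torch.stack(clips)                        # [num_files, TARGET_LEN]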
diff --git a/TorchJaekwon/Util/UtilHiFiGanWrapper.py b/TorchJaekwon/Util/UtilHiFiGanWrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..6f8a2af7557c5509689ecc90f43e52a6c3fa634e --- /dev/null +++ b/TorchJaekwon/Util/UtilHiFiGanWrapper.py @@ -0,0 +1,64 @@ +from typing import Union +from torch import Tensor +from numpy import ndarray + +import os +import json +import torch + +from HParams import HParams +from DataProcess.Util.UtilAudioMelSpec import UtilAudioMelSpec +from Model.vocoder.hifigan.env import AttrDict +from Model.vocoder.hifigan.models import Generator + +class UtilHiFiGanWrapper: + def __init__(self,h_params:HParams): + self.h_params:HParams = h_params + + self.util_mel = UtilAudioMelSpec(self.h_params) + + self.hifi_gan_generator : Generator + self.load_hifi_gan() + + def load_hifi_gan(self): + pretrain_name = self.h_params.process.hi_fi_gan_pretrained_name_list[self.h_params.process.hi_fi_gan_pretrain_idx] + pretrain_path = "./Model/vocoder/hifigan/pretrained/" + pretrain_name + config_file = os.path.join(pretrain_path, 'config.json') + with open(config_file) as f: + data = f.read() + json_config = json.loads(data) + h = AttrDict(json_config) + self.hifi_gan_generator = Generator(h) + if "MUS_DB" in pretrain_name: + state_dict_g = torch.load(pretrain_path + "/generator.pt", map_location='cpu') + self.hifi_gan_generator.load_state_dict(state_dict_g) + else: + state_dict_g = torch.load(pretrain_path + "/generator", map_location='cpu') + self.hifi_gan_generator.load_state_dict(state_dict_g['generator']) + self.hifi_gan_generator = self.hifi_gan_generator.to(self.h_params.resource.device) + self.hifi_gan_generator.eval() + self.hifi_gan_generator.remove_weight_norm() + + def audio_to_hifi_gan_mel(self,audio:Union[Tensor,ndarray]) -> Tensor: + audio_tensor:Tensor = audio if type(audio) == Tensor else torch.from_numpy(audio) + spectrogram:Tensor = self.util_mel.stft_torch(audio_tensor)["mag"] + mel:Tensor = self.util_mel.spec_to_mel_spec(spectrogram) + return self.util_mel.dynamic_range_compression(mel) + + def generate_audio_by_hifi_gan(self,input_feature:Union[Tensor,ndarray]) -> ndarray: + final_shape_len = 3 + + if type(input_feature) != torch.Tensor: + input_feature = torch.from_numpy(input_feature) + + for _ in range(final_shape_len - len(input_feature.shape)): + input_feature = torch.unsqueeze(input_feature, 0) + + input_feature = input_feature.to(self.h_params.resource.device) + + with torch.no_grad(): + #in: (batch,mel_size,time) , out: (batch,channel,time) + audio = self.hifi_gan_generator(input_feature) + audio = audio.squeeze() + audio = audio.cpu().numpy() + return audio \ No newline at end of file diff --git a/TorchJaekwon/Util/UtilTorch.py b/TorchJaekwon/Util/UtilTorch.py new file mode 100644 index 0000000000000000000000000000000000000000..6cca34029f07c6bedaf8d2636614c51fe3964505 --- /dev/null +++ b/TorchJaekwon/Util/UtilTorch.py @@ -0,0 +1,130 @@ +from typing import Any,Dict +from torch import Tensor, dtype, device +from numpy import ndarray + +import os +from collections import OrderedDict +import torch +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE + +class UtilTorch: + @staticmethod + def to_np(tensor:Tensor, do_squeeze:bool = True) -> ndarray: + if do_squeeze: + return tensor.squeeze().detach().cpu().numpy() + else: + return tensor.detach().cpu().numpy() + + @staticmethod + def to_torch(numpy_array:ndarray, dtype:dtype = torch.float32) -> Tensor: + return 
torch.tensor(numpy_array, dtype=dtype) + + @staticmethod + def register_buffer(model:nn.Module, + variable_name:str, + value:Any, + dtype:dtype = torch.float32) -> Any: + if type(value) != Tensor: + value = torch.tensor(value, dtype=dtype) + model.register_buffer(variable_name, value) + return getattr(model,variable_name) + + @staticmethod + def get_param_num(model:nn.Module) -> Dict[str,int]: + num_param : int = sum(param.numel() for param in model.parameters()) + trainable_param : int = sum(param.numel() for param in model.parameters() if param.requires_grad) + return {'total':num_param, 'trainable':trainable_param} + + @staticmethod + def freeze_param(model:nn.Module) -> nn.Module: + model = model.eval() + model.train = lambda self: self #override train with useless function + for param in model.parameters(): + param.requires_grad = False + return model + + @staticmethod + def get_model_device(model:nn.Module) -> device: + return next(model.parameters()).device + + @staticmethod + def interpolate_2d(input:Tensor, #[width, height] | [batch, width, height] | [batch, channels, width, height] + size_after_interpolation:tuple, #(width, height) + mode:str = 'nearest' + ) -> Tensor: + if len(input.shape) == 2: + shape_after_interpolation = size_after_interpolation + input = input.view(1,1,*(input.shape)) + elif len(input.shape) == 3: + shape_after_interpolation = (input.shape[0],*(size_after_interpolation)) + input = input.unsqueeze(1) + elif len(input.shape) == 4: + shape_after_interpolation = (input.shape[0],input.shape[1],*(size_after_interpolation)) + return F.interpolate(input, size = size_after_interpolation, mode=mode).view(shape_after_interpolation) + + @staticmethod + def tsne_plot(save_file_path:str, + class_array:ndarray, #[the number of data, 1] data must be integer for class. ex) [[1],[3],...] + embedding_array:ndarray, #[the number of data, channel_size] + figure_size:tuple = (10,10), + legend:str = 'full', + point_size:float = None #s=200 + ) -> None: + import pandas as pd + import seaborn as sns + assert os.path.splitext(save_file_path)[-1] == '.png', 'save_file_path should be *.png' + + print('generating t-SNE plot...') + tsne = TSNE(random_state=0) + tsne_output:ndarray = tsne.fit_transform(embedding_array) + + df = pd.DataFrame(tsne_output, columns=['x', 'y']) + df['class'] = class_array + + plt.rcParams['figure.figsize'] = figure_size + + scatterplot_args:dict = {'x':'x', 'y':'y', 'hue':'class', 'palette':sns.color_palette("hls", 10), + 'data':df, 'marker':'o', 'legend':legend, 'alpha':0.5} + if point_size is not None: scatterplot_args['s'] = point_size + sns.scatterplot(**scatterplot_args) + + plt.xticks([]) + plt.yticks([]) + plt.xlabel('') + plt.ylabel('') + + plt.savefig(save_file_path, bbox_inches='tight') + + @staticmethod + def update_ema(ema_model:nn.Module, model:nn.Module, decay:float=0.9999) -> None: + """ + Step the EMA model towards the current model. + """ + with torch.no_grad(): + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + name = name.replace("module.", "") + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + + @staticmethod + def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + @staticmethod + def kl_div_gaussian(mean1:Tensor, logvar1:Tensor, mean2:Tensor, logvar2:Tensor) -> Tensor: + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + + return 0.5 * ( -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)) \ No newline at end of file diff --git a/TorchJaekwon/Util/UtilVideo.py b/TorchJaekwon/Util/UtilVideo.py new file mode 100644 index 0000000000000000000000000000000000000000..8e6a55e0a8e0c03ea9f5356c83fe1c5f8a93d7be --- /dev/null +++ b/TorchJaekwon/Util/UtilVideo.py @@ -0,0 +1,69 @@ +from typing import Literal, Tuple + +import os +import subprocess +import numpy as np +from moviepy.editor import VideoFileClip, AudioFileClip, ImageClip +try: from pydub import AudioSegment +except: print('pydub is not installed. Please install it using `pip install pydub`') + +class UtilVideo: + @staticmethod + def extract_audio_from_video(video_path:str, + output_path:str = './tmp/tmp.wav', + ) -> str: + video = VideoFileClip(video_path) + audio = video.audio + os.makedirs(os.path.dirname(output_path), exist_ok=True) + audio.write_audiofile(output_path, codec='pcm_s16le') + return output_path + + @staticmethod + def attach_audio_to_video(video_path:str, + audio_path:str, + output_path:str, + fps:int=30, + video_duration_sec:float = None, + audio_codec:Literal['aac', 'pcm_s16le', 'pcm_s32le'] = 'aac', + ) -> VideoFileClip: + os.makedirs(os.path.dirname(output_path), exist_ok=True) + video_clip = VideoFileClip(video_path).set_fps(fps) + if video_duration_sec is not None: + video_clip = video_clip.subclip(0, video_duration_sec) + video_clip = video_clip.set_audio(AudioFileClip(audio_path)) + video_clip.write_videofile( + output_path, + audio=True, + audio_codec = audio_codec, + fps=fps, + verbose=False, + logger=None + ) + return video_clip + + @staticmethod + def attach_audio_to_img(image_path:str, + audio_path:str, + output_path:str = 'output.mkv', + audio_codec:Literal['aac', 'pcm_s16le', 'pcm_s32le'] = 'pcm_s32le', + audio_fps:int=44100, + video_size:Tuple[int,int]=(1920,1080), + module:Literal['moviepy', 'ffmpeg'] = 'moviepy' + ): + if module == 'moviepy': + import PIL + PIL.Image.ANTIALIAS = PIL.Image.LANCZOS + audio = AudioFileClip(audio_path) + image_clip:ImageClip = ImageClip(image_path).set_duration(audio.duration).resize(newsize=video_size) + video = image_clip.set_audio(audio) + video.write_videofile(output_path, + codec='libx264', + audio_fps = audio_fps, + audio_codec=audio_codec, + fps=24) + elif module == 'ffmpeg': + subprocess.run([ + 'ffmpeg', '-loop', '1', '-i', image_path, '-i', audio_path, + '-vf', f'scale={video_size[0]}:{video_size[1]}', '-c:v', 'libx264', + '-c:a', 'aac', '-shortest', output_path + ]) diff --git a/TorchJaekwon/Util/UtilWorldVocoder.py b/TorchJaekwon/Util/UtilWorldVocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d419ea6e31bd36f6dea7929d9174628d375d187c --- /dev/null +++ b/TorchJaekwon/Util/UtilWorldVocoder.py @@ -0,0 +1,236 @@ +import librosa +import numpy as np +import pyworld as pw +import pysptk.sptk as pysptk +import torch + +from HParams import HParams +class UtilWorldVocoder: + def __init__(self,h_params:HParams): + self.h_params = h_params + self.sample_rate = self.h_params.preprocess.sample_rate + self.n_fft = self.h_params.preprocess.nfft + self.hop_length = 
self.h_params.preprocess.hopsize + self.window_size = self.n_fft + self.world_frame_period = (self.hop_length / self.sample_rate) * 1000 + + def mag_phase_stft(self,audio): + stft = librosa.stft(audio,n_fft=self.h_params.preprocess.nfft, hop_length=self.h_params.preprocess.hopsize) + mag = abs(stft) + phase = np.exp(1.j * np.angle(stft)) + return {"mag":mag,"phase":phase} + + + def dynamic_range_compression(self, x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + def dynamic_range_compression_torch(self, x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + def normalize(self,x, min_db = -80.0, max_db = 20.0, clip_val = 0.8): + x = 2.0*(x - min_db)/(max_db - min_db) - 1.0 + x = torch.clamp(clip_val*x, -clip_val, clip_val) + return x + + def denormalize(self, x, min_db = -80.0, max_db = 20.0, clip_val = 0.8): + x = x/clip_val + x = (max_db - min_db)*(x + 1.0)/2.0 + min_db + return x + + def get_pred_accom_by_subtract_pred_vocal(self,pred_vocal,is_pred_vocal_audio,mix_audio): + pred_vocal_mag = pred_vocal + if is_pred_vocal_audio: + pred_vocal_mag = self.mag_phase_stft(pred_vocal)["mag"] + mix_stft = self.mag_phase_stft(mix_audio) + mix_mag = mix_stft["mag"] + mix_phase = mix_stft["phase"] + pred_accom_mag = mix_mag - pred_vocal_mag + pred_accom_mag[pred_accom_mag < 0] = 0 + pred_accom = librosa.istft(pred_accom_mag*mix_phase,hop_length=self.h_params.preprocess.hopsize,length=len(mix_audio)) + return pred_accom + + def get_compressed_world_parameters_from_audio(self,audio_mono): + print("start: compressed_world_parameters_from_audio") + world_parameters = pw.wav2world(audio_mono.astype("double"), self.sample_rate, frame_period=self.world_frame_period) + + f0 = world_parameters[0] + f0_midi = self.pitch_to_midi(f0) + interpolated_f0_midi,not_pitch = self.interpolate_f0_midi_nan_value(f0_midi) + + spectral_envelope = world_parameters[1] + spectral_envelope = 10*np.log10(spectral_envelope) + + aperiodic = world_parameters[2] + aperiodic = 10.*np.log10(aperiodic**2) + + if self.h_params.preprocess.compress_method_world_parameter == 'mfsc': + print("start: spectral sp_to_mfsc") + compressed_spectral = self.sp_to_mfsc(spectral_envelope, self.h_params.preprocess.num_spectral_coefficients,0.45) + print("start: aperiodic sp_to_mfsc") + compressed_aperiodic = self.sp_to_mfsc(aperiodic, self.h_params.preprocess.num_aperiodic_coefficients,0.45) + elif self.h_params.preprocess.compress_method_world_parameter == 'mgc': + print("start: spectral sp_to_mgc") + compressed_spectral = self.sp_to_mgc(spectral_envelope, self.h_params.preprocess.num_spectral_coefficients,0.45) + print("start: aperiodic sp_to_mgc") + compressed_aperiodic = self.sp_to_mgc(aperiodic, self.h_params.preprocess.num_aperiodic_coefficients,0.45) + + return { "f0": np.transpose(interpolated_f0_midi),"not_pitch":np.transpose(not_pitch.astype(int)), "spectral": np.transpose(compressed_spectral), "aperiodic": np.transpose(compressed_aperiodic) } + + def pitch_to_midi(self,frequency): + midi = 69 + 12 * np.log2(frequency/440) + return midi + + def midi_to_pitch(self,midi): + frequency = 440 * pow(2, (midi - 69) / 12) + return frequency + + def interpolate_f0_midi_nan_value(self,f0_midi): + infinite_conditional_index = np.isinf(f0_midi) + not_infinite_conditional_index = ~infinite_conditional_index + infinite_int_index = infinite_conditional_index.nonzero()[0] + not_infinite_int_index = (not_infinite_conditional_index).nonzero()[0] + + interpolated_f0_midi = f0_midi.copy() + 
interpolated_f0_midi[infinite_conditional_index] = np.interp(infinite_int_index,not_infinite_int_index,f0_midi[not_infinite_conditional_index]) + + interpolated_f0_midi = interpolated_f0_midi + not_pitch = infinite_conditional_index + + return (interpolated_f0_midi,not_pitch) + + def sp_to_mfsc(self,sp, ndim, fw, noise_floor_db=-120.0): + # helper function, sp->mgc->mfsc in a single step + mgc = self.sp_to_mgc(sp, ndim, fw, noise_floor_db) + mfsc = self.mgc_to_mfsc(mgc) + return mfsc + + def sp_to_mgc(self,sp, ndim, fw, noise_floor_db=-120.0): + # HTS uses -80, but we shift WORLD/STRAIGHT by -20 dB (so would be -100); use a little more headroom (SPTK uses doubles internally, so eps 1e-12 should still be OK) + dtype = sp.dtype + sp = sp.astype(np.float64) # required for pysptk + mgc = np.apply_along_axis(pysptk.mcep, 1, np.atleast_2d(sp), order=ndim-1, alpha=fw, maxiter=0, etype=1, eps=10**(noise_floor_db/10), min_det=0.0, itype=1) + if sp.ndim == 1: + mgc = mgc.flatten() + mgc = mgc.astype(dtype) + return mgc + + def mgc_to_mfsc(self,mgc): + is_1d = mgc.ndim == 1 + mgc = np.atleast_2d(mgc) + ndim = mgc.shape[1] + + # mirror cepstrum + mgc1 = np.concatenate([mgc[:, :], mgc[:, -2:0:-1]], axis=-1) + + # re-scale 'dc' and 'nyquist' cepstral bins (see mcep()) + mgc1[:, 0] *= 2 + mgc1[:, ndim-1] *= 2 + + # fft, truncate, to decibels + mfsc = np.real(np.fft.fft(mgc1)) + mfsc = mfsc[:, :ndim] + mfsc = 10*mfsc/np.log(10) + + if is_1d: + mfsc = mfsc.flatten() + + return mfsc + + def mfsc_to_mgc(self,mfsc): + # mfsc -> mgc -> sp is a much slower alternative to mfsc_to_sp() + is_1d = mfsc.ndim == 1 + mfsc = np.atleast_2d(mfsc) + ndim = mfsc.shape[1] + + mfsc = mfsc/10*np.log(10) + mfsc1 = np.concatenate([mfsc[:, :], mfsc[:, -2:0:-1]], axis=-1) + mgc = np.real(np.fft.ifft(mfsc1)) + mgc[:, 0] /= 2 + mgc[:, ndim-1] /= 2 + mgc = mgc[:, :ndim] + + if is_1d: + mgc = mgc.flatten() + + return mgc + + def mgc_to_sp(self,mgc, spec_size, fw): + dtype = mgc.dtype + mgc = mgc.astype(np.float64) # required for pysptk + fftlen = 2*(spec_size - 1) + sp = np.apply_along_axis(pysptk.mgc2sp, 1, np.atleast_2d(mgc), alpha=fw, gamma=0.0, fftlen=fftlen) + sp = 20*np.real(sp)/np.log(10) + if mgc.ndim == 1: + sp = sp.flatten() + sp = sp.astype(dtype) + return sp + + def get_audio_from_compressed_world_parameters(self, f0, not_pitch, spectral_compressed, aperiodic_compressed): + print("start: audio_from_compressed_world_parameters") + + is_pitch = (1-np.transpose(not_pitch)) + interpolated_f0 = self.midi_to_pitch(np.transpose(f0)) + f0_hz = (interpolated_f0 * is_pitch).astype('double') + + spectral = np.transpose(spectral_compressed) + aperiodic = np.transpose(aperiodic_compressed) + + if self.h_params.preprocess.compress_method_world_parameter == 'mfsc': + print("start: spectral mfsc_to_mgc") + spectral = self.mfsc_to_mgc(spectral) + print("start: aperiodic mfsc_to_mgc") + aperiodic = self.mfsc_to_mgc(aperiodic) + + print("start: spectral mgc_to_sp") + spectral = self.mgc_to_sp(spectral, self.h_params.preprocess.world_parameter_dimension, 0.45) + print("start: aperiodic mgc_to_sp") + aperiodic = self.mgc_to_sp(aperiodic, self.h_params.preprocess.world_parameter_dimension, 0.45) + + spectral = (10**(spectral/10)).astype('double') + aperiodic = (10**(aperiodic/20)).astype('double') + + print("start: synthesize audio") + audio = pw.synthesize(f0_hz,spectral,aperiodic,self.sample_rate,self.world_frame_period) + + return audio + + + def torch_A_weighting(self, frequencies, min_db = -45.0): + """ + Compute A-weighting weights in 
Decibel scale (codes from librosa) and + transform into amplitude domain (with DB-SPL equation). + + Argument: + frequencies : tensor of frequencies to return amplitude weight + min_db : mininum decibel weight. appropriate min_db value is important, as + exp/log calculation might raise numeric error with float32 type. + + Returns: + weights : tensor of amplitude attenuation weights corresponding to the frequencies tensor. + """ + + # Calculate A-weighting in Decibel scale. + frequencies_squared = frequencies ** 2 + const = torch.tensor([12200, 20.6, 107.7, 737.9]) ** 2.0 + weights_in_db = 2.0 + 20.0 * (torch.log10(const[0]) + 4 * torch.log10(frequencies) + - torch.log10(frequencies_squared + const[0]) + - torch.log10(frequencies_squared + const[1]) + - 0.5 * torch.log10(frequencies_squared + const[2]) + - 0.5 * torch.log10(frequencies_squared + const[3])) + + # Set minimum Decibel weight. + if min_db is not None: + weights_in_db = torch.max(weights_in_db, torch.tensor([min_db], dtype = torch.float32)) + + # Transform Decibel scale weight to amplitude scale weight. + weights = torch.exp(torch.log(torch.tensor([10.], dtype = torch.float32)) * weights_in_db / 10) + + return weights + + + +if __name__ == '__main__': + pa = HParams() + wo = UtilWorldVocoder(pa) + + diff --git a/TorchJaekwon/Util/smoothing_function.py b/TorchJaekwon/Util/smoothing_function.py new file mode 100644 index 0000000000000000000000000000000000000000..f047ebadd523f595134ebbac67d2e84b38a60159 --- /dev/null +++ b/TorchJaekwon/Util/smoothing_function.py @@ -0,0 +1,64 @@ +#%% +from matplotlib import pyplot as plt +import torch +from torch.nn import functional as F +import numpy as np +# %% +def smooth_lbl_loop(lbl, smooth_center, smooth_length, smooth_shape): + lbl_new = lbl.clone().detach().float() + lbl_copy = lbl.clone().detach() + lbl_weight = torch.zeros(smooth_length // 2) + + for i in range(lbl_weight.size(0)): + if smooth_shape == 'square': + lbl_weight[i] = 1 + elif smooth_shape == 'triangle': + lbl_weight[i] = 1 - (i + 1) / (smooth_length // 2 + 1) + elif smooth_shape == 'hann': + lbl_weight[i] = np.hanning(smooth_length + 2)[(smooth_length + 2) // 2 + 1 + i] + + for i in range(1, lbl_weight.size(0) + 1): + if smooth_center: + lbl_new[i:] += lbl_copy[:-i] * lbl_weight[i - 1] + lbl_new[:-i] += lbl_copy[i:] * lbl_weight[i - 1] + else: + lbl_new[i:] += lbl_copy[:-i] + + lbl_new[lbl_new > 1] = 1 + + return lbl_new + +# %% +def smooth_lbl_conv(lbl, smooth_center, smooth_length, smooth_shape): + lbl_new = lbl.clone().detach().cpu().float().unsqueeze(0) # [N, C, L] + lbl_weight = torch.zeros(1, 1, smooth_length) + + for i in range(lbl_weight.size(2)): + if smooth_shape == 'square': + lbl_weight[:, :, i] = 1 + elif smooth_shape == 'triangle': + if i < smooth_length // 2: + lbl_weight[:, :, i] = (i + 1) / (smooth_length // 2 + 1) + else: + lbl_weight[:, :, i] = 1 - (i - smooth_length // 2) / (smooth_length // 2 + 1) + elif smooth_shape == 'hann': + lbl_weight[:, :, i] = np.hanning(smooth_length + 2)[i + 1] + + if smooth_center: + lbl_new = F.conv1d(lbl_new, lbl_weight, bias=None, padding=smooth_length // 2).squeeze() + else: + lbl_new = F.conv1d(lbl_new, lbl_weight, bias=None).squeeze() + + lbl_new[lbl_new > 1] = 1 + + return lbl_new + +# %% +signal = torch.randint(0, 2, (50,)) +signal_smooth_loop = smooth_lbl_loop(signal, True, 3, 'triangle') +signal_smooth_conv = smooth_lbl_conv(signal, True, 3, 'triangle') + +plt.plot(signal) +#plt.plot(signal_smooth_loop, 'r') +plt.plot(signal_smooth_conv.squeeze(), 'g', 
linestyle='--') +# %% diff --git a/TorchJaekwon/__init__.py b/TorchJaekwon/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/enhance.py b/enhance.py new file mode 100644 index 0000000000000000000000000000000000000000..4f38f1fd26e292d6da75752d2a6ee73588a44e7d --- /dev/null +++ b/enhance.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +""" +Audio super-resolution using FlashSR. + +Independently written wrapper around the FlashSR model by Jaekwon Im and +Juhan Nam (KAIST). Supports files of arbitrary length via windowed processing +with overlap-add. No dependency on torchcodec or FFmpeg -- uses soundfile for +all I/O. + +Paper: https://arxiv.org/abs/2501.10807 +""" + +from __future__ import annotations + +import argparse +import math +import os +import sys +import time +from pathlib import Path + +import numpy as np +import soundfile as sf +import torch +from scipy.signal import resample_poly + +from FlashSR.FlashSR import FlashSR + +# ---- constants ---------------------------------------------------------------- + +TARGET_SR = 48_000 +WINDOW_LEN = 245_760 # samples per model call (5.12 s at 48 kHz) +OVERLAP = 24_000 # crossfade region (0.50 s) +HOP = WINDOW_LEN - OVERLAP # advance per window (4.62 s) + +AUDIO_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg", ".opus"} + + +# ---- helpers ------------------------------------------------------------------ + +def _load_mono(path: str | Path) -> tuple[np.ndarray, int]: + """Read an audio file, mix to mono, return (float32 array, sample_rate).""" + data, sr = sf.read(str(path), dtype="float32") + if data.ndim == 2: + data = data.mean(axis=1) + return data, sr + + +def _resample_if_needed(audio: np.ndarray, orig_sr: int) -> np.ndarray: + """Polyphase resample to TARGET_SR when the source rate differs.""" + if orig_sr == TARGET_SR: + return audio + return resample_poly(audio, TARGET_SR, orig_sr).astype(np.float32) + + +def _build_fade(length: int) -> torch.Tensor: + """Half-cosine fade-in ramp of *length* samples (0 -> 1).""" + t = torch.linspace(0.0, math.pi / 2, length) + return torch.sin(t) ** 2 # cos^2 fade is smooth at both ends + + +def _pad_to(tensor: torch.Tensor, n: int) -> torch.Tensor: + """Right-zero-pad the last dimension to at least *n* samples.""" + deficit = n - tensor.shape[-1] + if deficit <= 0: + return tensor + return torch.nn.functional.pad(tensor, (0, deficit)) + + +# ---- core --------------------------------------------------------------------- + +def build_model(weights_dir: str | Path, device: torch.device) -> FlashSR: + """Instantiate FlashSR and load pretrained weights.""" + w = Path(weights_dir) + model = FlashSR( + student_ldm_ckpt_path=str(w / "student_ldm.pth"), + sr_vocoder_ckpt_path=str(w / "sr_vocoder.pth"), + autoencoder_ckpt_path=str(w / "vae.pth"), + ) + return model.to(device).eval() + + +@torch.inference_mode() +def enhance( + model: FlashSR, + waveform: np.ndarray, + *, + device: torch.device, + lowpass: bool = False, +) -> np.ndarray: + """ + Run FlashSR on a mono waveform (numpy float32, 48 kHz). + + Long inputs are split into overlapping windows and reassembled with + overlap-add using a raised-cosine crossfade. + + Returns enhanced waveform as numpy float32 at 48 kHz. 
+ """ + signal = torch.from_numpy(waveform).unsqueeze(0) # (1, T) + n_samples = signal.shape[-1] + + # --- short signal: single pass ------------------------------------------- + if n_samples <= WINDOW_LEN: + chunk = _pad_to(signal, WINDOW_LEN).to(device) + out = model(chunk, lowpass_input=lowpass) + return out[0, :n_samples].cpu().numpy() + + # --- long signal: overlap-add -------------------------------------------- + fade = _build_fade(OVERLAP) + accumulator = torch.zeros(n_samples) + norm = torch.zeros(n_samples) + + offset = 0 + while offset < n_samples: + end = min(offset + WINDOW_LEN, n_samples) + segment = signal[:, offset:end] + segment = _pad_to(segment, WINDOW_LEN).to(device) + + enhanced_seg = model(segment, lowpass_input=lowpass).cpu().squeeze(0) + seg_len = min(WINDOW_LEN, n_samples - offset) + enhanced_seg = enhanced_seg[:seg_len] + + # per-sample weights: 1.0 everywhere, faded in at overlap boundary + w = torch.ones(seg_len) + if offset > 0 and seg_len > OVERLAP: + w[:OVERLAP] = fade + + accumulator[offset : offset + seg_len] += enhanced_seg * w + norm[offset : offset + seg_len] += w + offset += HOP + + norm.clamp_(min=1e-8) + return (accumulator / norm).numpy() + + +# ---- file-level convenience --------------------------------------------------- + +def enhance_file( + model: FlashSR, + src: str | Path, + dst: str | Path, + *, + device: torch.device, + lowpass: bool = False, +) -> float: + """Enhance one file. Returns duration in seconds.""" + raw, sr = _load_mono(src) + audio = _resample_if_needed(raw, sr) + result = enhance(model, audio, device=device, lowpass=lowpass) + os.makedirs(os.path.dirname(dst) or ".", exist_ok=True) + sf.write(str(dst), result, TARGET_SR) + return len(audio) / TARGET_SR + + +def collect_audio_files(root: str | Path) -> list[Path]: + """Recursively find audio files under *root*.""" + root = Path(root) + return sorted(p for p in root.rglob("*") if p.suffix.lower() in AUDIO_EXTENSIONS) + + +# ---- CLI ---------------------------------------------------------------------- + +def cli() -> None: + ap = argparse.ArgumentParser( + description="FlashSR audio super-resolution (by Im & Nam, KAIST)") + ap.add_argument("--input", "-i", required=True, + help="Input audio file or directory") + ap.add_argument("--output", "-o", required=True, + help="Output file or directory") + ap.add_argument("--weights", "-w", default="./weights", + help="Directory containing the three .pth weight files") + ap.add_argument("--lowpass", action="store_true", + help="Apply lowpass filter before enhancement") + ap.add_argument("--device", default="cuda", + help="Torch device (default: cuda)") + args = ap.parse_args() + + dev = torch.device(args.device if torch.cuda.is_available() else "cpu") + print(f"Device: {dev}") + + print("Loading model...") + t0 = time.monotonic() + model = build_model(args.weights, dev) + print(f"Loaded in {time.monotonic() - t0:.1f}s") + + # Resolve inputs + inp = Path(args.input) + out = Path(args.output) + + if inp.is_dir(): + files = collect_audio_files(inp) + if not files: + sys.exit(f"No audio files found in {inp}") + pairs = [(f, out / f.relative_to(inp)) for f in files] + else: + pairs = [(inp, out)] + + total_dur = 0.0 + t_start = time.monotonic() + + for idx, (src, dst) in enumerate(pairs, 1): + print(f"[{idx}/{len(pairs)}] {src} -> {dst}") + dur = enhance_file(model, src, dst, device=dev, lowpass=args.lowpass) + total_dur += dur + print(f" {dur:.1f}s of audio") + + elapsed = time.monotonic() - t_start + rtf = total_dur / elapsed if elapsed > 0 
else 0 + print(f"\nDone: {len(pairs)} file(s), {total_dur:.1f}s audio, " + f"{elapsed:.1f}s wall-clock ({rtf:.1f}x realtime)") + + +if __name__ == "__main__": + cli() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..62a875a6b68a25e86db94759015644e9fc54a87d --- /dev/null +++ b/setup.py @@ -0,0 +1,21 @@ +from setuptools import setup, find_packages + +setup( + name="FlashSRInfer", + version="0.1", + packages=find_packages(), + install_requires=[ + "numpy", + "tqdm", + "psutil", + "pyyaml", + "matplotlib", + "librosa", + "einops", + "soundfile", + "scipy", + ], + description="Redistribution of FlashSR audio super-resolution inference " + "(original authors: Jaekwon Im and Juhan Nam, KAIST)", + keywords="audio super-resolution speech enhancement", +)
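
To close the loop, a minimal usage sketch for the enhance.py entry point added above; the weight-file names follow build_model(), while the input/output paths here are placeholders.

# Command-line use (flags as defined in cli() above):
#   python enhance.py -i ./input_dir -o ./enhanced_dir -w ./weights --device cuda
#
# Programmatic use of the same helpers:
import torch
from enhance import TARGET_SR, build_model, enhance_file

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model("./weights", device)   # expects student_ldm.pth, sr_vocoder.pth, vae.pth
seconds = enhance_file(model, "speech.wav", "speech_48k.wav", device=device, lowpass=False)
print(f"processed {seconds:.1f}s of audio at {TARGET_SR} Hz")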