'''
SCNet - great paper, great implementation
https://arxiv.org/pdf/2401.13276.pdf
https://github.com/amanteur/SCNet-PyTorch
'''
| from typing import List, Tuple, Union | |
| import torch | |
def create_intervals(
    splits: List[Union[float, int]]
) -> List[Union[Tuple[float, float], Tuple[int, int]]]:
    """
    Create contiguous (start, end) intervals from a list of split widths.

    Each interval starts where the previous one ended; the first starts at 0.
    Args:
    - splits (List[Union[float, int]]): List of floats or integers representing splits.
    Returns:
    - List[Union[Tuple[float, float], Tuple[int, int]]]: List of tuples representing intervals.
    """
    intervals = []
    lower = 0
    for width in splits:
        upper = lower + width
        intervals.append((lower, upper))
        lower = upper
    return intervals
| def get_conv_output_shape( | |
| input_shape: int, | |
| kernel_size: int = 1, | |
| padding: int = 0, | |
| dilation: int = 1, | |
| stride: int = 1, | |
| ) -> int: | |
| """ | |
| Compute the output shape of a convolutional layer. | |
| Args: | |
| - input_shape (int): Input shape. | |
| - kernel_size (int, optional): Kernel size of the convolution. Default is 1. | |
| - padding (int, optional): Padding size. Default is 0. | |
| - dilation (int, optional): Dilation factor. Default is 1. | |
| - stride (int, optional): Stride value. Default is 1. | |
| Returns: | |
| - int: Output shape. | |
| """ | |
| return int( | |
| (input_shape + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 | |
| ) | |
def get_convtranspose_output_padding(
    input_shape: int,
    output_shape: int,
    kernel_size: int = 1,
    padding: int = 0,
    dilation: int = 1,
    stride: int = 1,
) -> int:
    """
    Compute the output padding for a convolution transpose operation.

    The result is the difference between the desired output shape and the
    shape a ConvTranspose layer would produce with zero output padding.
    Args:
    - input_shape (int): Input shape.
    - output_shape (int): Desired output shape.
    - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
    - padding (int, optional): Padding size. Default is 0.
    - dilation (int, optional): Dilation factor. Default is 1.
    - stride (int, optional): Stride value. Default is 1.
    Returns:
    - int: Output padding.
    """
    # Shape produced with output_padding == 0.
    base_output = (
        (input_shape - 1) * stride
        - 2 * padding
        + dilation * (kernel_size - 1)
        + 1
    )
    return output_shape - base_output
def compute_sd_layer_shapes(
    input_shape: int,
    bandsplit_ratios: List[float],
    downsample_strides: List[int],
    n_layers: int,
) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]:
    """
    Compute the shapes for the subband layers.
    Args:
    - input_shape (int): Input shape.
    - bandsplit_ratios (List[float]): Ratios for splitting the frequency bands.
    - downsample_strides (List[int]): Strides for downsampling in each layer.
    - n_layers (int): Number of layers.
    Returns:
    - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: Tuple containing subband shapes and convolution shapes.
    """
    subband_shapes: List[List[int]] = []
    conv_intervals: List[List[Tuple[int, int]]] = []
    freq_dim = input_shape
    for _ in range(n_layers):
        # Translate fractional band boundaries into integer bin widths.
        widths = []
        for lo, hi in create_intervals(bandsplit_ratios):
            widths.append(int(hi * freq_dim) - int(lo * freq_dim))
        # Each band is downsampled with its own stride.
        downsampled = [
            get_conv_output_shape(width, stride=stride)
            for width, stride in zip(widths, downsample_strides)
        ]
        subband_shapes.append(widths)
        conv_intervals.append(create_intervals(downsampled))
        # The next layer operates on the concatenation of downsampled bands.
        freq_dim = sum(downsampled)
    return subband_shapes, conv_intervals
def compute_gcr(subband_shapes: List[List[int]]) -> float:
    """
    Compute the global compression ratio.

    Averages, over consecutive layers, the mean relative shrinkage
    1 - next_shape / current_shape of the subband shapes.
    Args:
    - subband_shapes (List[List[int]]): List of subband shapes.
    Returns:
    - float: Global compression ratio.
    """
    shapes = torch.Tensor(subband_shapes)
    # Pair each layer's shapes with the next layer's via zip over rows.
    per_layer = [
        (1 - nxt / cur).mean()
        for cur, nxt in zip(shapes[:-1], shapes[1:])
    ]
    return float(torch.stack(per_layer).mean())