"""
SCNet shape-computation helpers.

Paper: https://arxiv.org/pdf/2401.13276.pdf
Reference implementation: https://github.com/amanteur/SCNet-PyTorch
"""

import math
from typing import List, Tuple, Union

import torch


def create_intervals(
    splits: List[Union[float, int]]
) -> List[Union[Tuple[float, float], Tuple[int, int]]]:
    """
    Turn a list of segment lengths into consecutive ``(start, end)`` intervals.

    The first interval starts at 0 and every following interval starts where
    the previous one ended, e.g. ``[1, 2, 3] -> [(0, 1), (1, 3), (3, 6)]``.

    Args:
    - splits (List[Union[float, int]]): Segment lengths (integers or ratios).

    Returns:
    - List[Union[Tuple[float, float], Tuple[int, int]]]: One interval per split.

    Note:
    - Float splits are accumulated with plain addition, so end points can carry
      rounding error (e.g. ``0.7 + 0.2 + 0.1 == 0.9999999999999999``).
    """
    start = 0
    # The tuple's first element is read before the walrus updates ``start``.
    return [(start, start := start + split) for split in splits]


def get_conv_output_shape(
    input_shape: int,
    kernel_size: int = 1,
    padding: int = 0,
    dilation: int = 1,
    stride: int = 1,
) -> int:
    """
    Compute the output length of a convolution along one axis.

    Uses PyTorch's formula
    ``floor((in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)``,
    evaluated in exact integer arithmetic.

    Args:
    - input_shape (int): Input length.
    - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
    - padding (int, optional): Padding size. Default is 0.
    - dilation (int, optional): Dilation factor. Default is 1.
    - stride (int, optional): Stride value. Default is 1.

    Returns:
    - int: Output length.
    """
    # Integer floor division matches PyTorch's floor() and, unlike float
    # division followed by int(), stays exact for arbitrarily large inputs.
    effective = input_shape + 2 * padding - dilation * (kernel_size - 1) - 1
    return effective // stride + 1


def get_convtranspose_output_padding(
    input_shape: int,
    output_shape: int,
    kernel_size: int = 1,
    padding: int = 0,
    dilation: int = 1,
    stride: int = 1,
) -> int:
    """
    Compute the ``output_padding`` a transposed convolution needs to map
    ``input_shape`` to ``output_shape``.

    This inverts PyTorch's transposed-convolution length formula
    ``out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
    + output_padding + 1``.

    Args:
    - input_shape (int): Input length.
    - output_shape (int): Desired output length.
    - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
    - padding (int, optional): Padding size. Default is 0.
    - dilation (int, optional): Dilation factor. Default is 1.
    - stride (int, optional): Stride value. Default is 1.

    Returns:
    - int: Output padding. The value is not range-checked here; PyTorch only
      accepts it when it is smaller than ``stride`` or ``dilation``.
    """
    return (
        output_shape
        - (input_shape - 1) * stride
        + 2 * padding
        - dilation * (kernel_size - 1)
        - 1
    )


def _ratio_to_index(ratio: float, size: int) -> int:
    """Map a cumulative band ratio to an index on an axis of length *size*."""
    # A cumulative ratio meant to be exactly 1.0 can land one ulp below it
    # (e.g. 0.7 + 0.2 + 0.1); int() would then truncate and drop the last bin.
    if math.isclose(ratio, 1.0):
        return size
    return int(ratio * size)


def compute_sd_layer_shapes(
    input_shape: int,
    bandsplit_ratios: List[float],
    downsample_strides: List[int],
    n_layers: int,
) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]:
    """
    Compute the band sizes for a stack of subband layers.

    For each layer, the current axis length is split into bands by
    ``bandsplit_ratios``. Each band is then downsampled with its stride from
    ``downsample_strides``. The output length is computed by
    ``get_conv_output_shape`` with its default kernel_size/padding/dilation.
    The sum of the downsampled band sizes becomes the next layer's input
    length.

    Args:
    - input_shape (int): Length of the axis being split for the first layer.
    - bandsplit_ratios (List[float]): Fraction of the axis assigned to each band.
    - downsample_strides (List[int]): Downsampling stride for each band; must
      have the same length as ``bandsplit_ratios``.
    - n_layers (int): Number of layers.

    Returns:
    - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: Per layer, the band
      sizes before downsampling, and the ``(start, end)`` intervals of the
      downsampled bands within that layer's output.

    Raises:
    - ValueError: If ``bandsplit_ratios`` and ``downsample_strides`` differ in
      length (``zip`` would otherwise silently drop bands).
    """
    if len(bandsplit_ratios) != len(downsample_strides):
        raise ValueError(
            "bandsplit_ratios and downsample_strides must have the same length, "
            f"got {len(bandsplit_ratios)} and {len(downsample_strides)}"
        )

    # Band boundaries, as fractions of the axis, are identical for every layer.
    bandsplit_intervals = create_intervals(bandsplit_ratios)

    bandsplit_shapes_list = []
    conv2d_shapes_list = []
    for _ in range(n_layers):
        bandsplit_shapes = [
            _ratio_to_index(right, input_shape) - _ratio_to_index(left, input_shape)
            for left, right in bandsplit_intervals
        ]
        conv2d_shapes = [
            get_conv_output_shape(bs, stride=ds)
            for bs, ds in zip(bandsplit_shapes, downsample_strides)
        ]
        # The next layer's input length is the total of this layer's
        # downsampled band sizes.
        input_shape = sum(conv2d_shapes)
        bandsplit_shapes_list.append(bandsplit_shapes)
        conv2d_shapes_list.append(create_intervals(conv2d_shapes))
    return bandsplit_shapes_list, conv2d_shapes_list


def compute_gcr(subband_shapes: List[List[int]]) -> float:
    """
    Compute the global compression ratio (GCR) of a stack of subband layers.

    For every pair of consecutive layers, each band's relative reduction
    ``1 - next / current`` is averaged over the bands. The GCR is the mean of
    those per-step averages.

    Args:
    - subband_shapes (List[List[int]]): Band sizes per layer, such as the first
      element returned by ``compute_sd_layer_shapes``. Every layer must have
      the same number of bands.

    Returns:
    - float: Global compression ratio.

    Raises:
    - ValueError: If fewer than two layers are given, since there is no
      compression step to measure.
    """
    t = torch.tensor(subband_shapes, dtype=torch.float32)
    if t.ndim != 2 or t.shape[0] < 2:
        raise ValueError(
            "compute_gcr needs band sizes for at least two layers, "
            f"got tensor of shape {tuple(t.shape)}"
        )
    # Row i holds the per-band reduction from layer i to layer i + 1.
    step_ratios = 1 - t[1:] / t[:-1]
    return float(step_ratios.mean(dim=1).mean())