poiqazwsx's picture
Upload 57 files
51e2f90
raw
history blame contribute delete
No virus
4.01 kB
'''
SCNet - great paper, great implementation
https://arxiv.org/pdf/2401.13276.pdf
https://github.com/amanteur/SCNet-PyTorch
'''
from typing import List, Tuple, Union
import torch
def create_intervals(
    splits: List[Union[float, int]]
) -> List[Union[Tuple[float, float], Tuple[int, int]]]:
    """
    Turn a sequence of split sizes into consecutive ``(start, end)`` intervals.

    The first interval starts at 0 and every following interval starts where
    the previous one ended, e.g. ``[2, 3]`` -> ``[(0, 2), (2, 5)]``.

    Args:
    - splits (List[Union[float, int]]): Sizes (widths) of the consecutive splits.
    Returns:
    - List[Union[Tuple[float, float], Tuple[int, int]]]: One ``(start, end)``
      tuple per split, in input order.
    """
    intervals = []
    left = 0
    for width in splits:
        right = left + width
        intervals.append((left, right))
        # The next interval begins exactly where this one ends.
        left = right
    return intervals
def get_conv_output_shape(
    input_shape: int,
    kernel_size: int = 1,
    padding: int = 0,
    dilation: int = 1,
    stride: int = 1,
) -> int:
    """
    Compute the output length of a convolution along one dimension.

    Implements the length formula documented for PyTorch convolution layers:
    ``floor((input + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``.
    Integer floor division is used so the result is exact for arbitrarily large
    inputs and rounds the same way (floor) as that formula.

    Args:
    - input_shape (int): Input shape.
    - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
    - padding (int, optional): Padding size. Default is 0.
    - dilation (int, optional): Dilation factor. Default is 1.
    - stride (int, optional): Stride value. Default is 1.
    Returns:
    - int: Output shape.
    Raises:
    - ZeroDivisionError: If ``stride`` is 0.
    """
    numerator = input_shape + 2 * padding - dilation * (kernel_size - 1) - 1
    # floor(a / s + 1) == a // s + 1; ``//`` avoids the float round-trip of the
    # previous ``int(a / s + 1)``, which lost precision for large values and
    # truncated toward zero instead of flooring. ``int()`` keeps the return an
    # int even if a float stride is passed.
    return int(numerator // stride + 1)
def get_convtranspose_output_padding(
    input_shape: int,
    output_shape: int,
    kernel_size: int = 1,
    padding: int = 0,
    dilation: int = 1,
    stride: int = 1,
) -> int:
    """
    Compute the ``output_padding`` a transposed convolution needs to reach a target length.

    A transposed convolution yields
    ``(input - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1``
    (the length formula documented for PyTorch ``ConvTranspose`` layers); this
    function solves that equation for ``output_padding``.

    Args:
    - input_shape (int): Input shape.
    - output_shape (int): Desired output shape.
    - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
    - padding (int, optional): Padding size. Default is 0.
    - dilation (int, optional): Dilation factor. Default is 1.
    - stride (int, optional): Stride value. Default is 1.
    Returns:
    - int: Output padding.
    """
    # Length the transposed convolution would produce with output_padding == 0.
    unpadded_length = (
        (input_shape - 1) * stride
        - 2 * padding
        + dilation * (kernel_size - 1)
        + 1
    )
    return output_shape - unpadded_length
def compute_sd_layer_shapes(
    input_shape: int,
    bandsplit_ratios: List[float],
    downsample_strides: List[int],
    n_layers: int,
) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]:
    """
    Compute the per-layer subband sizes and the intervals of their downsampled outputs.

    For each of ``n_layers`` layers, the current ``input_shape`` is split into
    subbands according to ``bandsplit_ratios``. Each subband length is passed
    through ``get_conv_output_shape`` with the stride at the same position in
    ``downsample_strides`` (kernel size 1, no padding). The summed downsampled
    lengths become the ``input_shape`` of the next layer.

    Args:
    - input_shape (int): Input shape.
    - bandsplit_ratios (List[float]): Ratios for splitting the frequency bands.
    - downsample_strides (List[int]): Downsampling stride for each band, one per
      entry of ``bandsplit_ratios``.
    - n_layers (int): Number of layers.
    Returns:
    - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: ``(subband_shapes,
      conv_intervals)``. ``subband_shapes[l]`` lists the subband lengths at layer
      ``l``. ``conv_intervals[l]`` lists the consecutive ``(start, end)`` ranges
      of the downsampled subbands at layer ``l``.
    Raises:
    - ValueError: If ``bandsplit_ratios`` and ``downsample_strides`` have
      different lengths (``zip`` would otherwise silently drop bands).
    """
    # Materialize once so one-shot iterables are not exhausted after layer 0.
    ratios = list(bandsplit_ratios)
    strides = list(downsample_strides)
    if len(ratios) != len(strides):
        raise ValueError(
            "bandsplit_ratios and downsample_strides must have the same length, "
            f"got {len(ratios)} and {len(strides)}"
        )
    # The ratio intervals do not depend on the layer, so compute them once.
    bandsplit_intervals = create_intervals(ratios)

    bandsplit_shapes_list = []
    conv2d_shapes_list = []
    for _ in range(n_layers):
        # Truncate each boundary before differencing so adjacent subbands share
        # boundaries (the lengths telescope without gaps or overlaps).
        bandsplit_shapes = [
            int(right * input_shape) - int(left * input_shape)
            for left, right in bandsplit_intervals
        ]
        conv2d_shapes = [
            get_conv_output_shape(bs, stride=ds)
            for bs, ds in zip(bandsplit_shapes, strides)
        ]
        input_shape = sum(conv2d_shapes)
        bandsplit_shapes_list.append(bandsplit_shapes)
        conv2d_shapes_list.append(create_intervals(conv2d_shapes))
    return bandsplit_shapes_list, conv2d_shapes_list
def compute_gcr(subband_shapes: List[List[int]]) -> float:
    """
    Compute the global compression ratio.

    For every pair of consecutive layers, the per-subband compression
    ``1 - shape[i + 1] / shape[i]`` is computed and averaged over the subbands.
    The per-pair averages are then averaged again.

    Args:
    - subband_shapes (List[List[int]]): Subband shapes, one inner list per
      layer. All inner lists must have the same length.
    Returns:
    - float: Global compression ratio.
    Raises:
    - ValueError: If fewer than two layers are given, because no ratio between
      consecutive layers can be formed (previously this surfaced as an opaque
      ``torch.stack`` error about an empty tensor list).
    """
    # torch.Tensor builds a tensor of the default floating dtype.
    t = torch.Tensor(subband_shapes)
    if len(t) < 2:
        raise ValueError(
            f"compute_gcr needs subband shapes for at least 2 layers, got {len(t)}"
        )
    # Pair each layer with the next one and average its per-subband compression.
    per_pair = [(1 - nxt / cur).mean() for cur, nxt in zip(t[:-1], t[1:])]
    return float(torch.stack(per_pair).mean())