File size: 4,010 Bytes
51e2f90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
'''
SCNet - great paper, great implementation
https://arxiv.org/pdf/2401.13276.pdf
https://github.com/amanteur/SCNet-PyTorch
'''
from typing import List, Tuple, Union
import torch
def create_intervals(
    splits: List[Union[float, int]]
) -> List[Union[Tuple[float, float], Tuple[int, int]]]:
    """
    Turn a sequence of segment lengths into consecutive (start, end) intervals.

    The first interval starts at 0 and every following interval starts where
    the previous one ended, e.g. ``[1, 2, 3] -> [(0, 1), (1, 3), (3, 6)]``.

    Args:
    - splits (List[Union[float, int]]): Lengths of the consecutive segments.
    Returns:
    - List[Union[Tuple[float, float], Tuple[int, int]]]: One (start, end) pair
      per entry of ``splits``; empty if ``splits`` is empty.
    """
    intervals = []
    lower = 0
    for width in splits:
        upper = lower + width
        intervals.append((lower, upper))
        # The next segment begins exactly where this one ends.
        lower = upper
    return intervals
def get_conv_output_shape(
    input_shape: int,
    kernel_size: int = 1,
    padding: int = 0,
    dilation: int = 1,
    stride: int = 1,
) -> int:
    """
    Compute the output length of a convolution along a single dimension.

    Evaluates
    ``floor((input_shape + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1``
    using integer floor division, so the result is exact even for very large
    sizes (no round-trip through floating point).

    Args:
    - input_shape (int): Input length along the convolved dimension.
    - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
    - padding (int, optional): Padding added on each side. Default is 0.
    - dilation (int, optional): Dilation factor. Default is 1.
    - stride (int, optional): Stride value. Default is 1.
    Returns:
    - int: Output length. No validation is done, so nonsensical inputs
      (e.g. a kernel larger than the padded input) can yield values <= 0.
    """
    # Length of the padded input minus the receptive field of one output element.
    span = input_shape + 2 * padding - dilation * (kernel_size - 1) - 1
    # Floor division matches the documented convolution formula, and int()
    # keeps the return type an int even if float sizes are passed in.
    return int(span // stride + 1)
def get_convtranspose_output_padding(
    input_shape: int,
    output_shape: int,
    kernel_size: int = 1,
    padding: int = 0,
    dilation: int = 1,
    stride: int = 1,
) -> int:
    """
    Find the ``output_padding`` a transposed convolution needs to hit a target size.

    Solves the transposed-convolution size formula
    ``output = (input - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
    + output_padding + 1`` for ``output_padding``. The result is not range-checked,
    so callers must ensure it is valid for their layer.

    Args:
    - input_shape (int): Input length along the convolved dimension.
    - output_shape (int): Desired output length.
    - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
    - padding (int, optional): Padding size. Default is 0.
    - dilation (int, optional): Dilation factor. Default is 1.
    - stride (int, optional): Stride value. Default is 1.
    Returns:
    - int: Output padding (desired length minus the length obtained with zero
      output padding).
    """
    unpadded_length = (
        (input_shape - 1) * stride
        - 2 * padding
        + dilation * (kernel_size - 1)
        + 1
    )
    return output_shape - unpadded_length
def compute_sd_layer_shapes(
    input_shape: int,
    bandsplit_ratios: List[float],
    downsample_strides: List[int],
    n_layers: int,
) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]:
    """
    Compute the per-layer subband sizes and downsampled subband intervals.

    For each layer, the current size is split into subbands according to
    ``bandsplit_ratios``. Each subband size is passed through
    ``get_conv_output_shape`` with the matching entry of ``downsample_strides``
    (default kernel size, padding and dilation). The sum of the downsampled
    subband sizes becomes the size fed to the next layer.

    Args:
    - input_shape (int): Size split by the first layer.
    - bandsplit_ratios (List[float]): Fraction of the size assigned to each subband.
    - downsample_strides (List[int]): Stride for each subband; must have the same
      length as ``bandsplit_ratios``.
    - n_layers (int): Number of layers.
    Returns:
    - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: For every layer, the
      subband sizes before downsampling, and the (start, end) intervals of the
      downsampled subbands laid out consecutively.
    Raises:
    - ValueError: If ``bandsplit_ratios`` and ``downsample_strides`` differ in length
      (previously ``zip`` silently dropped the extra entries).
    """
    if len(bandsplit_ratios) != len(downsample_strides):
        raise ValueError(
            "bandsplit_ratios and downsample_strides must have the same length, "
            f"got {len(bandsplit_ratios)} and {len(downsample_strides)}"
        )

    # The ratio intervals are the same for every layer, so build them once.
    bandsplit_intervals = create_intervals(bandsplit_ratios)

    bandsplit_shapes_list = []
    conv2d_shapes_list = []
    current_shape = input_shape
    for _ in range(n_layers):
        # Truncating each boundary (not each width) keeps adjacent bands
        # contiguous: the widths telescope to int(last_boundary * current_shape).
        bandsplit_shapes = [
            int(right * current_shape) - int(left * current_shape)
            for left, right in bandsplit_intervals
        ]
        conv2d_shapes = [
            get_conv_output_shape(band, stride=stride)
            for band, stride in zip(bandsplit_shapes, downsample_strides)
        ]
        current_shape = sum(conv2d_shapes)
        bandsplit_shapes_list.append(bandsplit_shapes)
        conv2d_shapes_list.append(create_intervals(conv2d_shapes))
    return bandsplit_shapes_list, conv2d_shapes_list
def compute_gcr(subband_shapes: List[List[int]]) -> float:
    """
    Compute the global compression ratio (GCR).

    For every pair of consecutive layers, the per-subband compression
    ``1 - shape[i + 1] / shape[i]`` is averaged over subbands; the GCR is the
    mean of these per-layer averages.

    Args:
    - subband_shapes (List[List[int]]): Subband sizes per layer; every inner list
      must have the same length.
    Returns:
    - float: Global compression ratio. A zero-sized subband produces inf/nan,
      as no special handling is done.
    Raises:
    - ValueError: If fewer than two layers are given, since there is no pair of
      consecutive layers to compare.
    """
    if len(subband_shapes) < 2:
        raise ValueError(
            "compute_gcr needs subband shapes for at least two layers, "
            f"got {len(subband_shapes)}"
        )
    # Default floating dtype mirrors the legacy torch.Tensor(...) constructor.
    shapes = torch.tensor(subband_shapes, dtype=torch.get_default_dtype())
    # Row i holds the per-subband compression from layer i to layer i + 1.
    compression = 1 - shapes[1:] / shapes[:-1]
    return float(compression.mean(dim=1).mean())