|
r"""Custom layers used in metrics computations""" |
|
import torch |
|
from typing import Optional |
|
|
|
from .filters import hann_filter |
|
|
|
|
|
class L2Pool2d(torch.nn.Module):
    r"""Applies L2 pooling with a Hann window of size ``kernel_size`` x ``kernel_size``.

    Every channel is convolved independently (``groups`` equals the number of
    channels) with the Hann window applied to the squared input.  The output is
    the square root of that result, plus a small epsilon that keeps ``sqrt``
    away from zero.

    Args:
        kernel_size: Side length of the square Hann window passed to ``hann_filter``.
        stride: Stride of the convolution.
        padding: Implicit padding forwarded to ``torch.nn.functional.conv2d``.
    """

    # Added under the square root so the result (and its gradient) stays finite for all-zero windows.
    EPS = 1e-12

    def __init__(self, kernel_size: int = 3, stride: int = 2, padding=1) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Lazily built in ``forward`` because the channel count (and device/dtype)
        # are only known once an input arrives.
        self.kernel: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Apply L2 pooling to ``x``.

        Args:
            x: Tensor with shape (N, C, H, W).

        Returns:
            ``sqrt(conv2d(x ** 2, hann) + EPS)`` computed per channel.
        """
        channels = x.size(1)

        # Rebuild the depthwise kernel whenever the channel count differs from the
        # cached one; reusing a stale kernel would make conv2d fail on the
        # groups/weight-shape mismatch.
        if self.kernel is None or self.kernel.size(0) != channels:
            self.kernel = hann_filter(self.kernel_size).repeat((channels, 1, 1, 1))

        # The kernel is a plain attribute (not a buffer), so Module.to() does not
        # move it; align it with the input every call. This is a no-op when the
        # device and dtype already match.
        self.kernel = self.kernel.to(x)

        out = torch.nn.functional.conv2d(
            x ** 2,
            self.kernel,
            stride=self.stride,
            padding=self.padding,
            groups=channels,
        )
        return (out + self.EPS).sqrt()
|
|