# ztrain/signal.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted

import torch


def gaussian_kernel(size, sigma=1.0):
    """
    Generates a 2D Gaussian kernel using PyTorch.

    Parameters:
    - size: The size of the kernel (an integer). It's recommended to use an odd
      number to have a central pixel. An even size is rounded up to the next odd
      size (e.g. 4 -> 5x5), because the kernel spans -size//2 .. size//2.
    - sigma: The standard deviation of the Gaussian distribution. Must be > 0.

    Returns:
    - A 2D PyTorch tensor of shape (2 * (size // 2) + 1,) * 2, normalized so its
      elements sum to 1.

    Raises:
    - ValueError: if sigma is not positive (it would otherwise yield NaNs).
    """
    if sigma <= 0:
        raise ValueError("sigma must be positive.")

    half = int(size) // 2
    coords = torch.arange(-half, half + 1)
    # Broadcasting a column against a row gives the same x**2 + y**2 grid as
    # meshgrid(indexing="ij"), without depending on meshgrid's indexing default.
    r2 = coords[:, None] ** 2 + coords[None, :] ** 2
    g = torch.exp(-r2 / (2 * sigma**2))
    return g / g.sum()


def laplacian_kernel(size, scale=1.0):
    """
    Creates a Laplacian kernel for edge detection with an adjustable size and scale factor.

    The kernel starts as the 4-neighbour stencil (-4 at the centre, 1 above,
    below, left and right). For size > 3, every border pixel is also set to 1.
    Any pixels between the stencil and the border stay 0. This is a simplistic
    construction that may need refinement for larger sizes. After scaling, the
    centre is recomputed so that the whole kernel sums to 0.

    Parameters:
    - size: The size of the kernel (an odd integer >= 3).
    - scale: A float that adjusts the intensity of the edge detection effect.

    Returns:
    - A (size, size) float32 PyTorch tensor representing the scaled Laplacian kernel.

    Raises:
    - ValueError: if size is even, or smaller than 3.
    """
    if size % 2 == 0:
        raise ValueError("Size must be odd.")
    if size < 3:
        # The 4-neighbour stencil below needs at least one pixel on each side of the centre.
        raise ValueError("Size must be at least 3.")

    center = size // 2
    kernel = torch.zeros((size, size), dtype=torch.float32)

    # 4-neighbour Laplacian stencil around the centre pixel
    kernel[center, center] = -4.0
    kernel[center, center - 1] = kernel[center, center + 1] = 1.0
    kernel[center - 1, center] = kernel[center + 1, center] = 1.0

    # For larger kernels, set the outer ring to 1 (simplistic; see docstring)
    if size > 3:
        kernel[0, :] = 1.0
        kernel[-1, :] = 1.0
        kernel[:, 0] = 1.0
        kernel[:, -1] = 1.0

    kernel *= scale

    # Make the centre the negated sum of all other entries so the kernel sums to 0
    kernel[center, center] = -torch.sum(kernel) + kernel[center, center]
    return kernel


def fftshift(input):
    """
    Reorients the FFT output so the zero-frequency component is at the center.

    Only dimensions 0 and 1 are shifted. For each of them, the data is rolled
    forward by n // 2, where n is that dimension's length. This matches
    torch.fft.fftshift / numpy.fft.fftshift for both even and odd lengths.

    Parameters:
    - input: A 2D tensor representing the FFT output.

    Returns:
    - A 2D tensor with the zero-frequency component shifted to the center.
    """
    for dim in range(2):  # assuming input is 2D
        n = input.shape[dim]
        # Rolling by n // 2 moves index 0 to n // 2 (the centre). Rolling by
        # (n + 1) // 2 would be the inverse (ifftshift) for odd n.
        input = torch.roll(input, shifts=n // 2, dims=dim)
    return input