maldv's picture
Upload folder using huggingface_hub
b59223f verified
raw
history blame
2.65 kB
# ztrain/signal.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
import torch
def gaussian_kernel(size, sigma=1.0):
    """
    Generates a normalized 2D Gaussian kernel using PyTorch.

    Parameters:
    - size: The size of the kernel (converted with int()). Only the half-width
      int(size) // 2 is used, so the kernel has 2 * (int(size) // 2) + 1 entries
      per side: an odd size yields a size x size kernel, while an even size
      yields a (size + 1) x (size + 1) kernel.
    - sigma: The standard deviation of the Gaussian distribution. Must be
      non-zero; sigma == 0 divides by zero and produces NaN values.

    Returns:
    - A 2D PyTorch tensor representing the Gaussian kernel, scaled so that its
      entries sum to 1.
    """
    half = int(size) // 2
    # Squared integer offsets from the central pixel along a single axis.
    offsets_sq = torch.arange(-half, half + 1) ** 2
    # Broadcasting a column against a row builds the x**2 + y**2 grid directly.
    # The values match what torch.meshgrid would give, but meshgrid is never
    # called, so its "pass the indexing argument" warning is not triggered.
    dist_sq = offsets_sq[:, None] + offsets_sq[None, :]
    g = torch.exp(-dist_sq / (2 * sigma**2))
    # Normalize so that filtering with the kernel preserves overall intensity.
    return g / g.sum()
def laplacian_kernel(size, scale=1.0):
    """
    Creates a Laplacian kernel for edge detection with an adjustable size and scale factor.

    The kernel is a 4-neighbour stencil (1.0 on the pixels directly above, below,
    left and right of the center). For sizes larger than 3 the outermost ring of
    pixels is also set to 1.0. All weights are multiplied by `scale`, and the
    center value is then chosen so that the whole kernel sums to 0.

    Parameters:
    - size: The size of the kernel (an integer). Must be odd and at least 3 so
      that there is a central pixel with a neighbour on every side.
    - scale: A float that adjusts the intensity of the edge detection effect.

    Returns:
    - A 2D float32 PyTorch tensor of shape (size, size) representing the scaled
      Laplacian kernel.

    Raises:
    - ValueError: If size is even, or if size is smaller than 3.
    """
    if size % 2 == 0:
        raise ValueError("Size must be odd.")
    if size < 3:
        # The stencil writes to center - 1 and center + 1; without this guard a
        # size of 1 fails with an IndexError and negative odd sizes fail inside
        # torch.zeros, neither of which tells the caller what went wrong.
        raise ValueError("Size must be at least 3.")

    center = size // 2
    kernel = torch.zeros((size, size), dtype=torch.float32)

    # Center pixel and its four immediate neighbours.
    kernel[center, center] = -4.0
    kernel[center, center - 1] = kernel[center, center + 1] = 1.0
    kernel[center - 1, center] = kernel[center + 1, center] = 1.0

    # For larger kernels, set the outer ring of pixels (this simplistic approach
    # might need refinement for larger sizes). Skipped for size 3, where the
    # outer ring would overwrite the stencil's corner zeros.
    if size > 3:
        kernel[0, :] = 1.0
        kernel[-1, :] = 1.0
        kernel[:, 0] = 1.0
        kernel[:, -1] = 1.0

    # Apply the scale factor
    kernel *= scale

    # Subtract the current total from the center so the kernel sums to 0.
    kernel[center, center] = kernel[center, center] - torch.sum(kernel)
    return kernel
def fftshift(input):
    """
    Reorients the FFT output so the zero-frequency component is at the center.

    Only the first two dimensions (dims 0 and 1) are shifted.

    Parameters:
    - input: A tensor representing the FFT output. It must have at least two
      dimensions; it is expected to be 2D.

    Returns:
    - A tensor of the same shape with the zero-frequency component shifted to
      the center.
    """
    # Rolling an axis of length n by n // 2 moves index 0 (the zero-frequency
    # bin) to index n // 2. That is the exact middle element for odd n and the
    # first element of the second half for even n. A shift of (n + 1) // 2 would
    # overshoot the middle by one position whenever n is odd.
    rows, cols = input.shape[0], input.shape[1]
    return torch.roll(input, shifts=(rows // 2, cols // 2), dims=(0, 1))