# SFM_Inference_Demo / util / post_processing.py
# Anirudh Bhalekar
# added models and util folder
# a3f0d6c
# raw
# history blame
# 11.5 kB
import torch
import numpy as np
import PIL.Image as Image
import torchvision.transforms as transforms
import torch.nn.functional as F
from typing import Optional, Tuple, Union
def morphological_open(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """
    Perform morphological opening (erosion followed by dilation) on a binary mask.

    Opening removes foreground specks that the square structuring element
    does not fit into, while regions the element fits into keep their extent.

    Args:
        image (torch.Tensor): mask to open, shape (H, W) or (C, H, W). Values
            greater than 0 are treated as foreground.
        kernel_size (int): side length of the square structuring element -
            roughly the size of speck to be removed. Use an odd size: an even
            size makes each pooling pass grow the output by one pixel.
    Returns:
        torch.Tensor: The opened mask as float32 zeros/ones, with one extra
            leading dimension compared to the input (e.g. (1, H, W) for an
            (H, W) input).
    """
    pad = kernel_size // 2
    foreground = (image > 0).float().unsqueeze(0)

    # Erosion is the complement of a dilated background. max_pool2d pads with
    # -inf, so pixels outside the image never erode the border.
    eroded = 1.0 - F.max_pool2d(1.0 - foreground, kernel_size, stride=1, padding=pad)

    # Dilation: a pixel becomes foreground if any pixel in its window is.
    return F.max_pool2d(eroded, kernel_size, stride=1, padding=pad)
def morphological_close(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """
    Perform morphological closing (dilation followed by erosion) on a binary mask.

    Closing fills holes and gaps that the square structuring element cannot
    fit into, without growing the outer extent of solid regions.

    Args:
        image (torch.Tensor): mask to close, shape (H, W) or (C, H, W). Values
            greater than 0 are treated as foreground.
        kernel_size (int): side length of the square structuring element -
            roughly the size of hole to be closed. Use an odd size: an even
            size makes each pooling pass grow the output by one pixel.
    Returns:
        torch.Tensor: The closed mask as float32 zeros/ones, with one extra
            leading dimension compared to the input (e.g. (1, H, W) for an
            (H, W) input).
    """
    pad = kernel_size // 2
    foreground = (image > 0).float().unsqueeze(0)

    # Dilation: a pixel becomes foreground if any pixel in its window is.
    dilated = F.max_pool2d(foreground, kernel_size, stride=1, padding=pad)

    # Erosion is the complement of a dilated background. max_pool2d pads with
    # -inf, so pixels outside the image never erode the border.
    return 1.0 - F.max_pool2d(1.0 - dilated, kernel_size, stride=1, padding=pad)
def gaussian_convolve(image: torch.Tensor, kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    """
    Gaussian convolution to smooth an image.

    Args:
        image (torch.Tensor): image to convolve, shape (H, W) or (1, H, W)
        kernel_size (int): side length of the square Gaussian kernel
        sigma (float): standard deviation of the Gaussian distribution
    Returns:
        torch.Tensor: The convolved image, with one extra leading dimension
            compared to the input.
    """
    # Sample offsets centred on zero, e.g. [-2, -1, 0, 1, 2] for kernel_size=5.
    offsets = torch.arange(
        (-kernel_size) // 2 + 1,
        kernel_size // 2 + 1,
        dtype=torch.float32,
        device=image.device,
    )
    squared = offsets ** 2

    # Broadcasting builds the 2-D grid of x^2 + y^2 without torch.meshgrid.
    kernel = torch.exp(-(squared.unsqueeze(1) + squared.unsqueeze(0)) / (2 * sigma**2))
    kernel = kernel / kernel.sum()

    # conv2d needs the weight dtype to match the input, so follow the image
    # when it is a floating-point type.
    if image.is_floating_point():
        kernel = kernel.to(image.dtype)

    # Apply the Gaussian kernel
    return F.conv2d(image.unsqueeze(0), kernel[None, None], stride=1, padding=kernel_size // 2)
def hysteresis_filter(image: torch.Tensor, low_threshold: float, high_threshold: float) -> torch.Tensor:
    """
    Hysteresis Filter Function - for Canny Edge detection

    Pixels above ``high_threshold`` are strong edges. Pixels above
    ``low_threshold`` are weak edges and are kept only when they are
    8-connected, possibly through other weak pixels, to a strong edge.

    Args:
        image (torch.Tensor): image to process; the last two dimensions are
            treated as (H, W) and any leading dimensions as independent images
        low_threshold (float): low threshold for hysteresis
        high_threshold (float): high threshold for hysteresis
    Returns:
        edge (torch.Tensor): float32 tensor of the same shape as ``image``,
            1.0 on edge pixels and 0.0 elsewhere.
    """
    weak = image > low_threshold
    # Seeds must also pass the low threshold so the result stays inside the
    # weak mask even if the thresholds are passed in the wrong order.
    seeds = (image > high_threshold) & weak

    # With no 2-D neighbourhood (or nothing to process) only the seeds remain.
    if image.dim() < 2 or image.numel() == 0:
        return seeds.float()

    height, width = image.shape[-2:]
    weak_mask = weak.reshape(-1, 1, height, width).float()
    edges = seeds.reshape(-1, 1, height, width).float()

    # Grow the edge set by one 8-connected step at a time, restricted to weak
    # pixels, until nothing changes. The set only grows, so this terminates.
    while True:
        grown = F.max_pool2d(edges, kernel_size=3, stride=1, padding=1) * weak_mask
        if torch.equal(grown, edges):
            break
        edges = grown

    return edges.reshape(image.shape)
def non_maxima_suppression_2d(
    image: torch.Tensor,
    kernel_size: int = 3,
    threshold: Optional[float] = None,
    return_mask: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Keep only the strictly positive local maxima of an image, zeroing the rest.

    A pixel survives when it equals the maximum of its
    ``kernel_size`` x ``kernel_size`` neighbourhood and is greater than zero.

    Args:
        image (torch.Tensor): Input tensor of shape (H, W), (C, H, W) or (B, C, H, W)
        kernel_size (int): Size of the local neighborhood for maxima detection (default: 3)
        threshold (float, optional): Values below this are zeroed before the maxima search
        return_mask (bool): If True, also return the boolean mask of local maxima
    Returns:
        torch.Tensor: Image with non-maxima suppressed, same shape as the input
        torch.Tensor (optional): Boolean mask of local maxima if return_mask=True
    Raises:
        ValueError: If the input is not 2-, 3- or 4-dimensional.
    """
    input_shape = image.shape
    rank = image.dim()
    if rank not in (2, 3, 4):
        raise ValueError(f"Unsupported tensor shape: {input_shape}")

    # Promote to (B, C, H, W) by prepending the missing leading axes.
    missing_axes = 4 - rank
    batched = image
    for _ in range(missing_axes):
        batched = batched.unsqueeze(0)

    # Zero out everything below the threshold, when one is given.
    if threshold is not None:
        zero = torch.tensor(0.0, device=batched.device)
        batched = torch.where(batched >= threshold, batched, zero)

    # A stride-1 max pool gives every pixel the maximum of its neighbourhood.
    neighbourhood_max = F.max_pool2d(
        batched,
        kernel_size=kernel_size,
        stride=1,
        padding=kernel_size // 2,
    )

    # Local maxima equal their neighbourhood maximum and are strictly positive.
    peaks = batched.eq(neighbourhood_max) & batched.gt(0)
    result = batched * peaks.float()

    # Drop the leading axes that were added above.
    for _ in range(missing_axes):
        result = result[0]
        peaks = peaks[0]

    return (result, peaks) if return_mask else result
def non_maxima_suppression_with_orientation(
    magnitude: torch.Tensor,
    orientation: torch.Tensor,
    threshold: Optional[float] = None
) -> torch.Tensor:
    """
    Perform oriented non-maxima suppression (commonly used in edge detection).

    Each pixel is compared with the two neighbours lying along its quantised
    orientation (0, 45, 90 or 135 degrees). It is kept only when it is
    non-zero and at least as large as both; pixels outside the image count
    as zero.

    Args:
        magnitude (torch.Tensor): Gradient magnitude tensor of shape (H, W), (C, H, W) or (B, C, H, W)
        orientation (torch.Tensor): Gradient orientation tensor (in radians) of same shape
        threshold (float, optional): Minimum magnitude threshold
    Returns:
        torch.Tensor: Non-maxima suppressed magnitude, same shape as the input
    Raises:
        ValueError: If ``magnitude`` is not 2-, 3- or 4-dimensional, or if
            ``orientation`` has a different shape.
    """
    original_shape = magnitude.shape
    if magnitude.dim() not in (2, 3, 4):
        raise ValueError(f"Unsupported tensor shape: {original_shape}")
    if orientation.shape != original_shape:
        raise ValueError(
            f"orientation shape {orientation.shape} does not match magnitude shape {original_shape}"
        )

    # Promote both tensors to (B, C, H, W).
    missing_axes = 4 - magnitude.dim()
    for _ in range(missing_axes):
        magnitude = magnitude.unsqueeze(0)
        orientation = orientation.unsqueeze(0)

    height, width = magnitude.shape[-2:]

    # Apply threshold if specified
    if threshold is not None:
        magnitude = torch.where(
            magnitude >= threshold, magnitude, torch.tensor(0.0, device=magnitude.device)
        )

    # Convert orientation to degrees and normalize to [0, 180)
    angle = torch.rad2deg(orientation) % 180

    # Zero border so neighbours outside the image compare as 0.
    padded = F.pad(magnitude, (1, 1, 1, 1), mode='constant', value=0)

    def neighbour(dy: int, dx: int) -> torch.Tensor:
        """Return the magnitude map shifted so each pixel sees (row+dy, col+dx)."""
        return padded[..., 1 + dy:1 + dy + height, 1 + dx:1 + dx + width]

    # Quantise the angle into four direction bins. Anything not matched by the
    # first three (including NaN or a value rounded up to 180) falls through
    # to the 135 degree bin.
    horizontal = ((angle >= 0) & (angle < 22.5)) | ((angle >= 157.5) & (angle < 180))
    diagonal_45 = (angle >= 22.5) & (angle < 67.5)
    vertical = (angle >= 67.5) & (angle < 112.5)

    first = torch.where(
        horizontal, neighbour(0, -1),
        torch.where(
            diagonal_45, neighbour(-1, 1),
            torch.where(vertical, neighbour(-1, 0), neighbour(-1, -1)),
        ),
    )
    second = torch.where(
        horizontal, neighbour(0, 1),
        torch.where(
            diagonal_45, neighbour(1, -1),
            torch.where(vertical, neighbour(1, 0), neighbour(1, 1)),
        ),
    )

    # Keep pixel if it's a non-zero local maximum along its direction
    keep = (magnitude != 0) & (magnitude >= first) & (magnitude >= second)
    suppressed = torch.where(keep, magnitude, torch.zeros_like(magnitude))

    # Reshape back to original shape
    for _ in range(missing_axes):
        suppressed = suppressed[0]
    return suppressed
def adaptive_non_maxima_suppression(
    image: torch.Tensor,
    num_points: int,
    min_distance: int = 5,
    threshold: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Greedily pick up to ``num_points`` of the strongest local maxima such that
    no two picked points are closer than ``min_distance``.

    Args:
        image (torch.Tensor): Input tensor of shape (H, W)
        num_points (int): Maximum number of points to select
        min_distance (int): Minimum Euclidean distance between selected points
        threshold (float, optional): Values below this are zeroed before the search
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: float32 coordinates (y, x) of shape
            (N, 2) and float32 values of shape (N,), strongest first. Both are
            empty when no point qualifies.
    Raises:
        ValueError: If ``image`` is not 2-dimensional.
    """
    if image.dim() != 2:
        raise ValueError("Input must be a 2D tensor")

    device = image.device

    # Zero out everything below the threshold, when one is given.
    if threshold is not None:
        image = torch.where(image >= threshold, image, torch.tensor(0.0, device=device))

    # Candidates are the positive 3x3 local maxima.
    peaks = non_maxima_suppression_2d(image, kernel_size=3)
    rows, cols = torch.nonzero(peaks > 0, as_tuple=True)
    strengths = peaks[rows, cols]

    if strengths.numel() == 0:
        return torch.empty((0, 2), device=device), torch.empty(0, device=device)

    # Walk the candidates from strongest to weakest as plain Python numbers.
    order = torch.argsort(strengths, descending=True)
    candidates = zip(
        rows[order].tolist(),
        cols[order].tolist(),
        strengths[order].tolist(),
    )

    kept_coords = []
    kept_values = []
    for y, x, value in candidates:
        if len(kept_coords) >= num_points:
            break

        # Reject a candidate that lies too close to any point already kept.
        too_close = any(
            ((y - kept_y) ** 2 + (x - kept_x) ** 2) ** 0.5 < min_distance
            for kept_y, kept_x in kept_coords
        )
        if too_close:
            continue

        kept_coords.append((y, x))
        kept_values.append(value)

    if not kept_coords:
        return torch.empty((0, 2), device=device), torch.empty(0, device=device)

    coords_tensor = torch.tensor(kept_coords, device=device, dtype=torch.float32)
    values_tensor = torch.tensor(kept_values, device=device, dtype=torch.float32)
    return coords_tensor, values_tensor