Spaces:
Build error
Build error
| import torch | |
| import numpy as np | |
| import PIL.Image as Image | |
| import torchvision.transforms as transforms | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple, Union | |
def morphological_open(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """
    Perform morphological opening (erosion followed by dilation) on a binary image.

    Pixels greater than zero are treated as foreground; everything outside the
    image is treated as background. Opening removes foreground structures that
    are smaller than the structuring element while keeping larger ones intact.

    Args:
        image (torch.Tensor): image to open. A leading dimension is added before
            filtering, so a 2D (H, W) input is processed as one single-channel image.
        kernel_size (int): side length of the square structuring element -
            roughly the size of the specks to be removed.
    Returns:
        torch.Tensor: The opened image as a float tensor of 0s and 1s, with the
        extra leading dimension of size 1 (i.e. (1, H, W) for a 2D input).
    """
    kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=image.device)
    # Splitting the padding this way keeps the spatial size for even kernels too;
    # for odd kernels both values equal kernel_size // 2.
    before, after = (kernel_size - 1) // 2, kernel_size // 2
    binary = (image > 0).float().unsqueeze(0)

    # Erosion: a pixel survives only if the whole window is foreground, i.e. the
    # box-filter sum reaches kernel_size**2 (the 0.5 margin guards against
    # floating-point round-off in the convolution).
    padded = F.pad(binary, (before, after, before, after), mode="constant", value=0.0)
    eroded = (F.conv2d(padded, kernel, stride=1) > kernel_size * kernel_size - 0.5).float()

    # Dilation with the mirrored window: a pixel is set if any pixel under the
    # window is foreground. Mirroring the padding keeps even kernels aligned.
    padded = F.pad(eroded, (after, before, after, before), mode="constant", value=0.0)
    dilated = F.conv2d(padded, kernel, stride=1)
    return (dilated > 0.5).float()
def morphological_close(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """
    Perform morphological closing (dilation followed by erosion) on a binary image.

    Pixels greater than zero are treated as foreground; everything outside the
    image is treated as background. Closing fills holes and gaps that are
    smaller than the structuring element without growing the objects.

    Args:
        image (torch.Tensor): image to close. A leading dimension is added before
            filtering, so a 2D (H, W) input is processed as one single-channel image.
        kernel_size (int): side length of the square structuring element -
            roughly the size of hole to be closed.
    Returns:
        torch.Tensor: The closed image as a float tensor of 0s and 1s, with the
        extra leading dimension of size 1 (i.e. (1, H, W) for a 2D input).
    """
    kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=image.device)
    binary = (image > 0).float().unsqueeze(0)

    # Extend the canvas by kernel_size - 1 background pixels on every side so the
    # dilation is not clipped at the image border. The two un-padded convolutions
    # below each shrink the canvas by kernel_size - 1, restoring the input size.
    span = kernel_size - 1
    padded = F.pad(binary, (span, span, span, span), mode="constant", value=0.0)

    # Dilation: a pixel is set if any pixel under the window is foreground.
    dilated = (F.conv2d(padded, kernel, stride=1) > 0.5).float()

    # Erosion: a pixel survives only if the whole window is foreground, i.e. the
    # box-filter sum reaches kernel_size**2 (0.5 margin against round-off).
    eroded = F.conv2d(dilated, kernel, stride=1)
    return (eroded > kernel_size * kernel_size - 0.5).float()
def gaussian_convolve(image: torch.Tensor, kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    """
    Smooth an image by convolving it with a normalized 2D Gaussian kernel.

    The kernel is created on the image's device and in the image's floating
    point dtype (integer/bool images are converted to float32), so the function
    works for CUDA and float64 inputs. The image is zero-padded so the output
    keeps the input's spatial size, for even kernel sizes as well.

    Args:
        image (torch.Tensor): image to convolve. A leading dimension is added
            before filtering, so a 2D (H, W) input is processed as one
            single-channel image.
        kernel_size (int): side length of the square Gaussian kernel
        sigma (float): standard deviation of the Gaussian distribution; must be positive
    Returns:
        torch.Tensor: The convolved image, with the extra leading dimension of
        size 1 (i.e. (1, H, W) for a 2D input).
    Raises:
        ValueError: if sigma is not positive (the kernel would be NaN).
    """
    if sigma <= 0:
        raise ValueError("sigma must be positive")

    dtype = image.dtype if image.is_floating_point() else torch.float32
    offsets = torch.arange(
        -kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=dtype, device=image.device
    )
    # The 2D Gaussian is separable: exp(-(x^2 + y^2) / 2s^2) is the outer product
    # of two 1D profiles, which avoids building a meshgrid.
    profile = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
    kernel = profile[:, None] * profile[None, :]
    kernel = kernel / kernel.sum()

    # The zero offset sits at index (kernel_size - 1) // 2, so pad that many
    # pixels before and kernel_size // 2 after to keep the kernel centred on
    # each output pixel (identical to symmetric padding for odd sizes).
    before, after = (kernel_size - 1) // 2, kernel_size // 2
    padded = F.pad(image.to(dtype).unsqueeze(0), (before, after, before, after), mode="constant", value=0.0)
    return F.conv2d(padded, kernel[None, None], stride=1)
def hysteresis_filter(image: torch.Tensor, low_threshold: float, high_threshold: float) -> torch.Tensor:
    """
    Hysteresis Filter Function - for Canny Edge detection

    Pixels above ``high_threshold`` are strong edges. Pixels above
    ``low_threshold`` are kept only if they are connected (8-connectivity)
    to a strong edge through a chain of pixels that are all above
    ``low_threshold``. Connectivity is evaluated over the last two dimensions;
    any leading dimensions are processed independently.

    Args:
        image (torch.Tensor): image to process
        low_threshold (float): low threshold for hysteresis
        high_threshold (float): high threshold for hysteresis
    Returns:
        edge (torch.Tensor): The edges detected in the image, as a float tensor
        of 0s and 1s with the same shape as ``image``.
    """
    candidates = image > low_threshold
    # Seeds must also pass the low threshold so the result never contains a
    # pixel at or below it, even if the thresholds are given in swapped order.
    edges = (image > high_threshold) & candidates

    # No 2D neighbourhood exists for 0-D/1-D or empty inputs: plain thresholding.
    if image.dim() < 2 or image.numel() == 0:
        return edges.float()

    shape = edges.shape
    candidates = candidates.float().reshape(-1, 1, shape[-2], shape[-1])
    edges = edges.float().reshape(-1, 1, shape[-2], shape[-1])

    # Grow the edge set by one pixel per iteration (3x3 max-pool = binary
    # dilation), restricted to candidate pixels, until nothing changes. The set
    # only ever grows and is bounded by the candidates, so the loop terminates.
    while True:
        grown = F.max_pool2d(edges, kernel_size=3, stride=1, padding=1) * candidates
        if torch.equal(grown, edges):
            break
        edges = grown

    return edges.reshape(shape)
def non_maxima_suppression_2d(
    image: torch.Tensor,
    kernel_size: int = 3,
    threshold: Optional[float] = None,
    return_mask: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Keep only the strictly positive local maxima of an image and zero everything else.

    A pixel is a local maximum when it equals the largest value inside the
    kernel_size x kernel_size window centred on it.

    Args:
        image (torch.Tensor): Input tensor of shape (H, W), (C, H, W) or (B, C, H, W)
        kernel_size (int): Side length of the window used to look for maxima (default: 3)
        threshold (float, optional): Values below this are zeroed before the search
        return_mask (bool): If True, also return the boolean mask of local maxima
    Returns:
        torch.Tensor: Image with non-maxima suppressed, in the input's shape
        torch.Tensor (optional): Boolean mask of local maxima if return_mask=True
    Raises:
        ValueError: if the input does not have 2, 3 or 4 dimensions.
    """
    input_shape = image.shape
    rank = len(input_shape)
    if rank not in (2, 3, 4):
        raise ValueError(f"Unsupported tensor shape: {input_shape}")

    # Promote to (B, C, H, W) by adding the missing leading dimensions.
    missing_dims = 4 - rank
    for _ in range(missing_dims):
        image = image.unsqueeze(0)

    if threshold is not None:
        zero = torch.tensor(0.0, device=image.device)
        image = torch.where(image >= threshold, image, zero)

    # The sliding-window maximum equals the pixel itself exactly at local maxima.
    window_max = F.max_pool2d(
        image, kernel_size=kernel_size, stride=1, padding=kernel_size // 2
    )
    is_peak = image == window_max
    is_positive = image > 0
    mask = is_peak & is_positive

    suppressed = image * mask.float()

    # Drop the leading dimensions that were added above.
    for _ in range(missing_dims):
        suppressed = suppressed.squeeze(0)
        mask = mask.squeeze(0)

    return (suppressed, mask) if return_mask else suppressed
def non_maxima_suppression_with_orientation(
    magnitude: torch.Tensor,
    orientation: torch.Tensor,
    threshold: Optional[float] = None
) -> torch.Tensor:
    """
    Perform oriented non-maxima suppression (commonly used in edge detection).

    Each pixel is compared with its two neighbours along the direction given by
    its orientation, quantised to 0, 45, 90 or 135 degrees. The pixel keeps its
    magnitude if it is greater than or equal to both neighbours and is set to
    zero otherwise. Neighbours outside the image count as zero.

    Args:
        magnitude (torch.Tensor): Gradient magnitude tensor of shape (H, W), (C, H, W) or (B, C, H, W)
        orientation (torch.Tensor): Gradient orientation tensor (in radians) of same shape
        threshold (float, optional): Minimum magnitude threshold; smaller values are zeroed first
    Returns:
        torch.Tensor: Non-maxima suppressed magnitude, in the input's shape
    Raises:
        ValueError: if the tensors do not have 2, 3 or 4 dimensions or their shapes differ.
    """
    original_shape = magnitude.shape
    rank = len(original_shape)
    if rank not in (2, 3, 4):
        raise ValueError(f"Unsupported tensor shape: {original_shape}")
    if orientation.shape != original_shape:
        raise ValueError(
            f"magnitude and orientation must have the same shape, "
            f"got {original_shape} and {orientation.shape}"
        )

    # Promote both tensors to (B, C, H, W).
    missing_dims = 4 - rank
    for _ in range(missing_dims):
        magnitude = magnitude.unsqueeze(0)
        orientation = orientation.unsqueeze(0)
    height, width = magnitude.shape[-2:]

    # Apply threshold if specified
    if threshold is not None:
        magnitude = torch.where(
            magnitude >= threshold, magnitude, torch.tensor(0.0, device=magnitude.device)
        )

    # Convert orientation to degrees and normalize to [0, 180]
    angle = torch.rad2deg(orientation) % 180

    # Zero-padded copy so every pixel has all eight neighbours.
    mag_padded = F.pad(magnitude, (1, 1, 1, 1), mode='constant', value=0)

    def neighbor(dy: int, dx: int) -> torch.Tensor:
        # View of the magnitude shifted so element (y, x) holds pixel (y + dy, x + dx).
        return mag_padded[..., 1 + dy:1 + dy + height, 1 + dx:1 + dx + width]

    # Quantise the angle. `>= 157.5` has no upper bound on purpose: a tiny
    # negative angle rounds to exactly 180.0 after the modulo and must still be
    # treated as horizontal. Anything not matched below (the 112.5-157.5 band,
    # and NaN angles) falls through to the 135 degree diagonal.
    horizontal = (angle < 22.5) | (angle >= 157.5)
    diagonal_45 = (angle >= 22.5) & (angle < 67.5)
    vertical = (angle >= 67.5) & (angle < 112.5)

    first = torch.where(
        horizontal, neighbor(0, -1),
        torch.where(
            diagonal_45, neighbor(-1, 1),
            torch.where(vertical, neighbor(-1, 0), neighbor(-1, -1)),
        ),
    )
    second = torch.where(
        horizontal, neighbor(0, 1),
        torch.where(
            diagonal_45, neighbor(1, -1),
            torch.where(vertical, neighbor(1, 0), neighbor(1, 1)),
        ),
    )

    # Keep pixel if it's a local maximum along its gradient direction.
    keep = (magnitude >= first) & (magnitude >= second)
    suppressed = torch.where(keep, magnitude, torch.zeros_like(magnitude))

    # Reshape back to original shape
    for _ in range(missing_dims):
        suppressed = suppressed.squeeze(0)
    return suppressed
def adaptive_non_maxima_suppression(
    image: torch.Tensor,
    num_points: int,
    min_distance: int = 5,
    threshold: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pick up to ``num_points`` of the strongest local maxima, greedily from
    strongest to weakest, skipping any candidate closer than ``min_distance``
    (Euclidean) to a point that was already picked.

    Args:
        image (torch.Tensor): Input tensor of shape (H, W)
        num_points (int): Maximum number of points to select
        min_distance (int): Minimum distance between selected points
        threshold (float, optional): Values below this are zeroed before the search
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: float32 tensor of (y, x) coordinates
        with one row per selected point, and float32 tensor of their values.
        Both are empty (shapes (0, 2) and (0,)) when nothing is selected.
    Raises:
        ValueError: if the input is not a 2D tensor.
    """
    if image.dim() != 2:
        raise ValueError("Input must be a 2D tensor")
    device = image.device

    if threshold is not None:
        image = torch.where(image >= threshold, image, torch.tensor(0.0, device=device))

    # Candidates are the positive local maxima in a 3x3 neighbourhood.
    peaks = non_maxima_suppression_2d(image, kernel_size=3)
    rows, cols = torch.nonzero(peaks > 0, as_tuple=True)
    strengths = peaks[rows, cols]
    if strengths.numel() == 0:
        return torch.empty((0, 2), device=device), torch.empty(0, device=device)

    # Visit candidates from strongest to weakest, as plain Python numbers.
    order = torch.argsort(strengths, descending=True)
    candidates = zip(
        rows[order].tolist(), cols[order].tolist(), strengths[order].tolist()
    )

    kept_coords = []
    kept_values = []
    for y, x, value in candidates:
        if len(kept_coords) >= num_points:
            break
        too_close = any(
            ((y - kept_y) ** 2 + (x - kept_x) ** 2) ** 0.5 < min_distance
            for kept_y, kept_x in kept_coords
        )
        if too_close:
            continue
        kept_coords.append((y, x))
        kept_values.append(value)

    if not kept_coords:
        return torch.empty((0, 2), device=device), torch.empty(0, device=device)

    coords_tensor = torch.tensor(kept_coords, device=device, dtype=torch.float32)
    values_tensor = torch.tensor(kept_values, device=device, dtype=torch.float32)
    return coords_tensor, values_tensor