Spaces:
Build error
Build error
File size: 11,493 Bytes
a3f0d6c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 |
import torch
import numpy as np
import PIL.Image as Image
import torchvision.transforms as transforms
import torch.nn.functional as F
from typing import Optional, Tuple, Union
def morphological_open(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """
    Perform morphological opening (erosion followed by dilation) on a binary image.

    Pixels greater than zero are treated as foreground. Opening removes
    foreground specks smaller than the structuring element while leaving
    larger regions unchanged.

    Args:
        image (torch.Tensor): image to open (e.g. an (H, W) mask). A leading
            dimension is added before the tensor is passed to ``F.conv2d``.
        kernel_size (int): side length of the square structuring element -
            roughly the size of the speck to be removed. Odd sizes keep the
            spatial size; with an even size each convolution enlarges the
            output by one pixel because the padding is ``kernel_size // 2``.
    Returns:
        torch.Tensor: The opened image as a float tensor of zeros and ones,
        including the added leading dimension.
    """
    kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=image.device)
    pad = kernel_size // 2
    # Binarise first so the convolutions only ever see exact 0/1 float values.
    foreground = (image.unsqueeze(0) > 0).float()

    # Erosion is the complement of the dilated background: a pixel survives
    # only if no background pixel lies under the structuring element.  The
    # zero padding of the background means "outside the image" counts as
    # foreground, so regions touching the border are not eaten away.
    background_hits = F.conv2d(1.0 - foreground, kernel, stride=1, padding=pad)
    eroded = (background_hits < 0.5).float()

    # Dilation: a pixel is set if any foreground pixel lies under the element.
    dilated = F.conv2d(eroded, kernel, stride=1, padding=pad)
    return (dilated > 0).float()
def morphological_close(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """
    Perform morphological closing (dilation followed by erosion) on a binary image.

    Pixels greater than zero are treated as foreground. Closing fills
    background holes and gaps smaller than the structuring element while
    leaving the outline of larger regions unchanged.

    Args:
        image (torch.Tensor): image to close (e.g. an (H, W) mask). A leading
            dimension is added before the tensor is passed to ``F.conv2d``.
        kernel_size (int): side length of the square structuring element -
            roughly the size of hole to be closed. Odd sizes keep the spatial
            size; with an even size each convolution enlarges the output by
            one pixel because the padding is ``kernel_size // 2``.
    Returns:
        torch.Tensor: The closed image as a float tensor of zeros and ones,
        including the added leading dimension.
    """
    kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=image.device)
    pad = kernel_size // 2
    # Binarise first so the convolutions only ever see exact 0/1 float values.
    foreground = (image.unsqueeze(0) > 0).float()

    # Dilation: a pixel is set if any foreground pixel lies under the element.
    dilated = (F.conv2d(foreground, kernel, stride=1, padding=pad) > 0).float()

    # Erosion is the complement of the dilated background: a pixel survives
    # only if no background pixel lies under the structuring element.  The
    # zero padding of the background means "outside the image" counts as
    # foreground, so foreground touching the border is never removed.
    background_hits = F.conv2d(1.0 - dilated, kernel, stride=1, padding=pad)
    return (background_hits < 0.5).float()
def gaussian_convolve(image: torch.Tensor, kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    """
    Smooth an image by convolving it with a normalised 2D Gaussian kernel.

    Args:
        image (torch.Tensor): image to convolve (e.g. an (H, W) tensor). A
            leading dimension is added before the tensor is passed to
            ``F.conv2d``. Non-floating-point images are converted to float32.
        kernel_size (int): side length of the square Gaussian kernel
        sigma (float): standard deviation of the Gaussian distribution
    Returns:
        torch.Tensor: The convolved image, including the added leading
        dimension. The borders are zero-padded by ``kernel_size // 2``.
    """
    if not image.is_floating_point():
        image = image.float()

    # Build the kernel on the image's device so CUDA inputs do not hit a
    # device mismatch inside conv2d.
    offsets = torch.arange(
        -kernel_size // 2 + 1,
        kernel_size // 2 + 1,
        dtype=torch.float32,
        device=image.device,
    )
    # Broadcasting yields the full grid of squared distances from the kernel
    # centre without relying on torch.meshgrid's deprecated default indexing.
    squared_distance = offsets[:, None] ** 2 + offsets[None, :] ** 2
    kernel = torch.exp(-squared_distance / (2 * sigma**2))
    # Normalise so the weights sum to one, then match the image dtype
    # (conv2d requires input and weight to share a dtype).
    kernel = (kernel / kernel.sum()).to(image.dtype)

    # Apply the Gaussian kernel
    return F.conv2d(image.unsqueeze(0), kernel[None, None], stride=1, padding=kernel_size // 2)
def hysteresis_filter(image: torch.Tensor, low_threshold: float, high_threshold: float) -> torch.Tensor:
    """
    Hysteresis Filter Function - for Canny Edge detection

    Pixels above ``high_threshold`` (and above ``low_threshold``) are strong
    edges and are always kept. Pixels that are only above ``low_threshold``
    are kept when they are 8-connected, directly or through other such
    pixels, to a strong edge. Everything else is set to zero.

    Args:
        image (torch.Tensor): image to process. The last two dimensions are
            treated as the spatial (H, W) dimensions; any leading dimensions
            are processed independently. Tensors with fewer than two
            dimensions have no spatial neighbourhood and are simply
            thresholded.
        low_threshold (float): low threshold for hysteresis
        high_threshold (float): high threshold for hysteresis
    Returns:
        edge (torch.Tensor): float tensor of zeros and ones with the same
        shape as ``image`` marking the detected edges.
    """
    candidates = image > low_threshold
    strong = (image > high_threshold) & candidates

    # Nothing to propagate through without a 2D neighbourhood or any pixels.
    if image.dim() < 2 or image.numel() == 0:
        return strong.float()

    height, width = image.shape[-2:]
    edges = strong.float().reshape(-1, 1, height, width)
    allowed = candidates.float().reshape(-1, 1, height, width)

    # Perform hysteresis thresholding: repeatedly grow the edge set by one
    # pixel in every direction (3x3 max-pool), keeping only pixels that pass
    # the low threshold, until the set stops changing.
    while True:
        grown = F.max_pool2d(edges, kernel_size=3, stride=1, padding=1) * allowed
        if torch.equal(grown, edges):
            break
        edges = grown

    return edges.reshape(image.shape)
def non_maxima_suppression_2d(
    image: torch.Tensor,
    kernel_size: int = 3,
    threshold: Optional[float] = None,
    return_mask: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Zero out every pixel that is not the maximum of its local neighbourhood.

    A pixel is kept when it is strictly positive and equal to the maximum of
    the ``kernel_size`` x ``kernel_size`` window centred on it.

    Args:
        image (torch.Tensor): Input tensor of shape (H, W), (C, H, W) or (B, C, H, W)
        kernel_size (int): Side length of the window used to find local maxima (default: 3)
        threshold (float, optional): Values below this are set to zero before the search
        return_mask (bool): If True, also return the boolean mask of kept pixels
    Returns:
        torch.Tensor: Image with non-maxima set to zero, in the input's shape
        torch.Tensor (optional): Boolean mask of local maxima if return_mask=True
    Raises:
        ValueError: If the input does not have 2, 3 or 4 dimensions.
    """
    input_shape = image.shape
    rank = len(input_shape)
    if rank not in (2, 3, 4):
        raise ValueError(f"Unsupported tensor shape: {input_shape}")

    # Promote to (B, C, H, W) by adding the missing leading dimensions.
    missing_dims = 4 - rank
    for _ in range(missing_dims):
        image = image.unsqueeze(0)

    # Drop weak responses up front so they can never be reported as maxima.
    if threshold is not None:
        zero = torch.tensor(0.0, device=image.device)
        image = torch.where(image >= threshold, image, zero)

    # The sliding-window maximum equals the pixel itself exactly at local maxima.
    window_max = F.max_pool2d(
        image,
        kernel_size=kernel_size,
        stride=1,
        padding=kernel_size // 2,
    )
    mask = (image > 0) & (image == window_max)
    suppressed = image * mask.float()

    # Undo the promotion so the outputs match the caller's layout.
    for _ in range(missing_dims):
        suppressed = suppressed.squeeze(0)
        mask = mask.squeeze(0)

    return (suppressed, mask) if return_mask else suppressed
def non_maxima_suppression_with_orientation(
    magnitude: torch.Tensor,
    orientation: torch.Tensor,
    threshold: Optional[float] = None
) -> torch.Tensor:
    """
    Perform oriented non-maxima suppression (commonly used in edge detection).

    Each pixel's orientation is quantised to one of four directions
    (0, 45, 90 or 135 degrees). The pixel is kept only if its magnitude is
    non-zero and at least as large as both neighbours along that direction;
    neighbours outside the image count as zero.

    Args:
        magnitude (torch.Tensor): Gradient magnitude tensor of shape (H, W),
            (C, H, W) or (B, C, H, W)
        orientation (torch.Tensor): Gradient orientation tensor (in radians) of same shape
        threshold (float, optional): Minimum magnitude threshold; smaller
            magnitudes are set to zero before suppression
    Returns:
        torch.Tensor: Non-maxima suppressed magnitude, in the input's shape
    Raises:
        ValueError: If ``magnitude`` does not have 2, 3 or 4 dimensions, or if
            ``orientation`` has a different shape.
    """
    original_shape = magnitude.shape
    rank = len(original_shape)
    if rank not in (2, 3, 4):
        raise ValueError(f"Unsupported tensor shape: {original_shape}")
    if orientation.shape != original_shape:
        raise ValueError(
            f"orientation shape {orientation.shape} does not match magnitude shape {original_shape}"
        )

    # Promote both tensors to (B, C, H, W).
    missing_dims = 4 - rank
    for _ in range(missing_dims):
        magnitude = magnitude.unsqueeze(0)
        orientation = orientation.unsqueeze(0)

    # Apply threshold if specified
    if threshold is not None:
        magnitude = torch.where(magnitude >= threshold, magnitude, torch.zeros_like(magnitude))

    # Convert orientation to degrees and normalize to [0, 180]
    angle = torch.rad2deg(orientation) % 180

    # Zero-pad once so every pixel has all eight neighbours available.
    height, width = magnitude.shape[-2:]
    padded = F.pad(magnitude, (1, 1, 1, 1), mode='constant', value=0)

    def neighbour(dy: int, dx: int) -> torch.Tensor:
        """Return the magnitude map shifted so entry (i, j) holds pixel (i + dy, j + dx)."""
        return padded[..., 1 + dy:1 + dy + height, 1 + dx:1 + dx + width]

    # Quantise the angle into four direction bins.  The horizontal bin is
    # open-ended at the top so that a value rounding to exactly 180.0 after
    # the modulo (e.g. a tiny negative angle) is still treated as horizontal.
    horizontal = (angle < 22.5) | (angle >= 157.5)
    diagonal_45 = (angle >= 22.5) & (angle < 67.5)
    vertical = (angle >= 67.5) & (angle < 112.5)
    # Everything else (112.5 <= angle < 157.5, and NaN angles) uses the
    # 135-degree diagonal, which is the innermost fallback below.

    first_neighbour = torch.where(
        horizontal, neighbour(0, -1),
        torch.where(
            diagonal_45, neighbour(-1, 1),
            torch.where(vertical, neighbour(-1, 0), neighbour(-1, -1)),
        ),
    )
    second_neighbour = torch.where(
        horizontal, neighbour(0, 1),
        torch.where(
            diagonal_45, neighbour(1, -1),
            torch.where(vertical, neighbour(1, 0), neighbour(1, 1)),
        ),
    )

    # Keep pixel if it's a local maximum along its gradient direction
    keep = (magnitude != 0) & (magnitude >= first_neighbour) & (magnitude >= second_neighbour)
    suppressed = torch.where(keep, magnitude, torch.zeros_like(magnitude))

    # Reshape back to original shape
    for _ in range(missing_dims):
        suppressed = suppressed.squeeze(0)
    return suppressed
def adaptive_non_maxima_suppression(
    image: torch.Tensor,
    num_points: int,
    min_distance: int = 5,
    threshold: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Greedily pick the strongest local maxima that are mutually far enough apart.

    Local maxima are found with ``non_maxima_suppression_2d`` (3x3 window),
    visited from strongest to weakest, and accepted when they lie at least
    ``min_distance`` (Euclidean) away from every point accepted so far.
    Selection stops once ``num_points`` points have been accepted.

    Args:
        image (torch.Tensor): Input tensor of shape (H, W)
        num_points (int): Maximum number of points to select
        min_distance (int): Minimum distance between selected points
        threshold (float, optional): Values below this are set to zero first
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: float32 coordinates (y, x) of shape
        (N, 2) and the float32 values of the selected points, shape (N,)
    Raises:
        ValueError: If ``image`` is not two-dimensional.
    """
    if len(image.shape) != 2:
        raise ValueError("Input must be a 2D tensor")

    device = image.device

    def no_points() -> Tuple[torch.Tensor, torch.Tensor]:
        # Shared empty result for "no maxima" and "nothing selected".
        return torch.empty((0, 2), device=device), torch.empty(0, device=device)

    # Discard weak responses before looking for peaks.
    if threshold is not None:
        image = torch.where(image >= threshold, image, torch.tensor(0.0, device=device))

    # Every surviving positive pixel of the plain NMS result is a candidate.
    peaks = non_maxima_suppression_2d(image, kernel_size=3)
    rows, cols = torch.nonzero(peaks > 0, as_tuple=True)
    strengths = peaks[rows, cols]
    if len(strengths) == 0:
        return no_points()

    # Visit candidates from strongest to weakest.
    order = torch.argsort(strengths, descending=True)
    candidates = zip(rows[order].tolist(), cols[order].tolist(), strengths[order].tolist())

    kept_coords = []
    kept_values = []
    for y, x, value in candidates:
        if len(kept_coords) >= num_points:
            break
        # Reject the candidate if it crowds any point that was already accepted.
        too_close = any(
            ((y - kept_y) ** 2 + (x - kept_x) ** 2) ** 0.5 < min_distance
            for kept_y, kept_x in kept_coords
        )
        if too_close:
            continue
        kept_coords.append((y, x))
        kept_values.append(value)

    if not kept_coords:
        return no_points()

    coords_tensor = torch.tensor(kept_coords, device=device, dtype=torch.float32)
    values_tensor = torch.tensor(kept_values, device=device, dtype=torch.float32)
    return coords_tensor, values_tensor
|