File size: 11,493 Bytes
a3f0d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import torch
import numpy as np
import PIL.Image as Image
import torchvision.transforms as transforms
import torch.nn.functional as F
from typing import Optional, Tuple, Union


def morphological_open(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """
    Perform morphological opening (erosion followed by dilation) on a binary image.

    Opening removes foreground structures smaller than the structuring element
    while leaving larger structures at their original extent.

    Args: 
        image (torch.Tensor): single-channel image to open, e.g. (H, W) or (1, H, W).
            Pixels with value > 0 are treated as foreground.
        kernel_size (int): side length of the square structuring element - roughly the
            size of the objects to be removed. Use an odd value; the padding of
            kernel_size // 2 only preserves the spatial size for odd kernels.

    Returns: 
        torch.Tensor: The opened image as a float32 0/1 tensor. A leading dimension
            of size 1 is added to the input shape (the input is unsqueezed before
            convolution).
    """
    pad = kernel_size // 2
    kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=image.device)
    foreground = (image.unsqueeze(0) > 0).float()

    # Erosion: a pixel survives only if NO background pixel lies under the kernel.
    # Counting background pixels with zero padding means out-of-image positions are
    # never counted as background, so the image border itself does not erode objects.
    background_hits = F.conv2d(1.0 - foreground, kernel, stride=1, padding=pad)
    # Compare against 0.5 rather than 0 so tiny floating-point noise from the
    # convolution backend cannot flip an exact integer count.
    eroded = (background_hits < 0.5).float()

    # Dilation: a pixel is set if ANY eroded pixel lies under the kernel.
    foreground_hits = F.conv2d(eroded, kernel, stride=1, padding=pad)
    return (foreground_hits > 0.5).float()

def morphological_close(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """
    Perform morphological closing (dilation followed by erosion) on a binary image.

    Closing fills holes and gaps smaller than the structuring element while
    leaving the outer extent of larger structures unchanged.

    Args: 
        image (torch.Tensor): single-channel image to close, e.g. (H, W) or (1, H, W).
            Pixels with value > 0 are treated as foreground.
        kernel_size (int): side length of the square structuring element - roughly the
            size of hole to be closed. Use an odd value; the padding of
            kernel_size // 2 only preserves the spatial size for odd kernels.
    
    Returns: 
        torch.Tensor: The closed image as a float32 0/1 tensor. A leading dimension
            of size 1 is added to the input shape (the input is unsqueezed before
            convolution).
    """
    pad = kernel_size // 2
    kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=image.device)
    foreground = (image.unsqueeze(0) > 0).float()

    # Dilation: a pixel is set if ANY foreground pixel lies under the kernel.
    # Thresholds of 0.5 keep exact integer counts robust to floating-point noise.
    foreground_hits = F.conv2d(foreground, kernel, stride=1, padding=pad)
    dilated = (foreground_hits > 0.5).float()

    # Erosion: a pixel survives only if NO background pixel lies under the kernel.
    # Zero padding of the background map means out-of-image positions never count
    # as background, so structures touching the border are not eaten away.
    background_hits = F.conv2d(1.0 - dilated, kernel, stride=1, padding=pad)
    return (background_hits < 0.5).float()

def gaussian_convolve(image: torch.Tensor, kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    """
    Gaussian Convolution to smooth image

    The kernel is built on the same device and (floating) dtype as the image so the
    function works for CUDA and float64 inputs; non-floating images are converted
    to float32 before filtering.

    Args:
        image (torch.Tensor): single-channel image to convolve, e.g. (H, W) or (1, H, W)
        kernel_size (int): size of the Gaussian kernel (odd values keep the spatial size)
        sigma (float): standard deviation of the Gaussian distribution

    Returns: 
        torch.Tensor: The convolved image, with a leading dimension of size 1 added
            to the input shape (the input is unsqueezed before convolution).
    """
    work_dtype = image.dtype if image.is_floating_point() else torch.float32
    image = image.to(work_dtype)

    # Integer offsets of each kernel tap from the kernel centre.
    offsets = torch.arange(
        -kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=work_dtype, device=image.device
    )
    # Broadcasting builds the 2-D squared-distance grid without torch.meshgrid
    # (whose default indexing mode is deprecated).
    squared_radius = offsets[:, None] ** 2 + offsets[None, :] ** 2
    kernel = torch.exp(-squared_radius / (2 * sigma**2))
    # Normalise so that smoothing preserves overall intensity.
    kernel = kernel / kernel.sum()

    # Apply the Gaussian kernel
    return F.conv2d(image.unsqueeze(0), kernel[None, None], stride=1, padding=kernel_size // 2)

def hysteresis_filter(image: torch.Tensor, low_threshold: float, high_threshold: float) -> torch.Tensor:
    """
    Hysteresis Filter Function - for Canny Edge detection

    Pixels above ``high_threshold`` are strong edges. Pixels above ``low_threshold``
    are kept only if they are connected (8-connectivity, possibly through other
    weak pixels) to a strong edge. Everything else is discarded.
    
    Args: 
        image (torch.Tensor): image to process; the last two dimensions are treated
            as (H, W), any leading dimensions are processed independently
        low_threshold (float): low threshold for hysteresis
        high_threshold (float): high threshold for hysteresis

    Returns: 
        edge (torch.Tensor): float32 0/1 tensor of the same shape as ``image``
            marking the edges detected in the image.   
    
    """
    strong = image > high_threshold
    candidates = image > low_threshold
    # Seeds must also pass the low threshold so that swapped thresholds still
    # yield "above both thresholds", never a pixel below either one.
    edges = (strong & candidates).float()

    # Connectivity is only defined for tensors with spatial dimensions.
    if image.dim() < 2 or image.numel() == 0:
        return edges

    height, width = image.shape[-2:]
    allowed = candidates.float().reshape(-1, 1, height, width)
    grown = edges.reshape(-1, 1, height, width)

    # Perform hysteresis thresholding: repeatedly dilate the edge set by one pixel
    # and clip it to the weak-candidate mask until nothing new is added. The set
    # only ever grows and is bounded by `allowed`, so the loop terminates.
    while True:
        expanded = F.max_pool2d(grown, kernel_size=3, stride=1, padding=1) * allowed
        if torch.equal(expanded, grown):
            break
        grown = expanded

    return grown.reshape(image.shape)

def non_maxima_suppression_2d(
    image: torch.Tensor,
    kernel_size: int = 3,
    threshold: Optional[float] = None,
    return_mask: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Keep only strictly positive pixels that equal the maximum of their local window.

    Args:
        image (torch.Tensor): Input of shape (H, W), (C, H, W) or (B, C, H, W)
        kernel_size (int): Side length of the square window used to find local maxima
        threshold (float, optional): Pixels below this value are zeroed before the search
        return_mask (bool): When True, also return the boolean mask of local maxima

    Returns:
        torch.Tensor: Image in its original shape with every non-maximum set to zero
        torch.Tensor (optional): Boolean mask of the retained pixels if return_mask=True

    Raises:
        ValueError: If the input is not 2-, 3- or 4-dimensional
    """
    input_shape = image.shape
    rank = len(input_shape)
    if rank not in (2, 3, 4):
        raise ValueError(f"Unsupported tensor shape: {input_shape}")

    # Pooling runs on (B, C, H, W); prepend however many axes are missing.
    added_axes = 4 - rank
    for _ in range(added_axes):
        image = image.unsqueeze(0)

    if threshold is not None:
        image = torch.where(image >= threshold, image, torch.tensor(0.0, device=image.device))

    # Stride-1 max pooling gives, at every pixel, the maximum of its window.
    window_max = F.max_pool2d(
        image, kernel_size=kernel_size, stride=1, padding=kernel_size // 2
    )

    # A pixel is a peak when it is positive and matches its window maximum.
    is_peak = (image == window_max) & (image > 0)
    kept = image * is_peak.float()

    # Drop the axes that were prepended so the output mirrors the input shape.
    for _ in range(added_axes):
        kept = kept.squeeze(0)
        is_peak = is_peak.squeeze(0)

    return (kept, is_peak) if return_mask else kept


def non_maxima_suppression_with_orientation(
    magnitude: torch.Tensor,
    orientation: torch.Tensor,
    threshold: Optional[float] = None
) -> torch.Tensor:
    """
    Perform oriented non-maxima suppression (commonly used in edge detection).

    Each pixel is compared with its two neighbours along the quantised gradient
    direction (0°, 45°, 90° or 135°) and kept only if it is non-zero and at least
    as large as both. Pixels outside the image count as zero. The computation is
    fully vectorised, so it runs without a per-pixel Python loop.
    
    Args:
        magnitude (torch.Tensor): Gradient magnitude tensor of shape (H, W), (C, H, W)
            or (B, C, H, W)
        orientation (torch.Tensor): Gradient orientation tensor (in radians) of same shape
        threshold (float, optional): Minimum magnitude threshold; smaller values are
            zeroed before suppression
    
    Returns:
        torch.Tensor: Non-maxima suppressed magnitude, same shape as the input

    Raises:
        ValueError: If ``magnitude`` is not 2-, 3- or 4-dimensional, or if
            ``orientation`` does not have the same shape as ``magnitude``
    """
    original_shape = magnitude.shape
    rank = magnitude.dim()
    if rank not in (2, 3, 4):
        raise ValueError(f"Unsupported tensor shape: {original_shape}")
    if orientation.shape != original_shape:
        raise ValueError(
            f"orientation shape {orientation.shape} does not match magnitude shape {original_shape}"
        )

    # Work on (B, C, H, W); prepend however many axes are missing.
    added_axes = 4 - rank
    for _ in range(added_axes):
        magnitude = magnitude.unsqueeze(0)
        orientation = orientation.unsqueeze(0)

    height, width = magnitude.shape[-2:]

    # Apply threshold if specified
    if threshold is not None:
        magnitude = torch.where(
            magnitude >= threshold, magnitude, torch.tensor(0.0, device=magnitude.device)
        )

    # Convert orientation to degrees and normalize to [0, 180)
    angle = torch.rad2deg(orientation) % 180

    # Zero-pad by one pixel so every pixel has all eight neighbours available.
    padded = F.pad(magnitude, (1, 1, 1, 1), mode='constant', value=0)

    def neighbour(dy: int, dx: int) -> torch.Tensor:
        # View of the magnitude shifted so that element (y, x) holds the value
        # of pixel (y + dy, x + dx), with zeros outside the image.
        return padded[..., 1 + dy:1 + dy + height, 1 + dx:1 + dx + width]

    # Quantise the angle into four directions. The horizontal bin is open-ended
    # above 157.5 so that a value rounding to exactly 180.0 is still horizontal.
    # Anything matching none of the first three bins falls through to 135°.
    horizontal = (angle < 22.5) | (angle >= 157.5)
    diagonal_45 = (angle >= 22.5) & (angle < 67.5)
    vertical = (angle >= 67.5) & (angle < 112.5)

    first = torch.where(
        horizontal, neighbour(0, -1),
        torch.where(
            diagonal_45, neighbour(-1, 1),
            torch.where(vertical, neighbour(-1, 0), neighbour(-1, -1)),
        ),
    )
    second = torch.where(
        horizontal, neighbour(0, 1),
        torch.where(
            diagonal_45, neighbour(1, -1),
            torch.where(vertical, neighbour(1, 0), neighbour(1, 1)),
        ),
    )

    # Keep pixel if it's a non-zero local maximum along its gradient direction
    keep = (magnitude != 0) & (magnitude >= first) & (magnitude >= second)
    suppressed = torch.where(keep, magnitude, torch.zeros_like(magnitude))

    # Reshape back to original shape
    for _ in range(added_axes):
        suppressed = suppressed.squeeze(0)

    return suppressed


def adaptive_non_maxima_suppression(
    image: torch.Tensor,
    num_points: int,
    min_distance: int = 5,
    threshold: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Greedily pick up to ``num_points`` of the strongest local maxima such that no
    two picked points lie closer than ``min_distance`` (Euclidean) to each other.

    Args:
        image (torch.Tensor): Input tensor of shape (H, W)
        num_points (int): Upper bound on the number of points returned
        min_distance (int): Minimum Euclidean distance allowed between two picks
        threshold (float, optional): Pixels below this value are zeroed beforehand

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A float32 tensor of (y, x) coordinates,
        one row per selected point, and a float32 tensor with the matching values,
        both ordered from strongest to weakest. Both are empty when nothing is
        selected.

    Raises:
        ValueError: If ``image`` is not two-dimensional
    """
    if image.dim() != 2:
        raise ValueError("Input must be a 2D tensor")

    device = image.device

    if threshold is not None:
        image = torch.where(image >= threshold, image, torch.tensor(0.0, device=device))

    # Candidates are the 3x3 local maxima of the (optionally thresholded) image.
    peaks = non_maxima_suppression_2d(image, kernel_size=3)
    rows, cols = torch.nonzero(peaks > 0, as_tuple=True)
    strengths = peaks[rows, cols]

    if strengths.numel() == 0:
        return torch.empty((0, 2), device=device), torch.empty(0, device=device)

    # Visit candidates from strongest to weakest.
    order = torch.argsort(strengths, descending=True)
    candidates = zip(
        rows[order].tolist(), cols[order].tolist(), strengths[order].tolist()
    )

    kept_coords = []
    kept_values = []
    for y, x, value in candidates:
        if len(kept_coords) >= num_points:
            break

        # Reject the candidate if any already-accepted point is too close.
        too_close = any(
            ((y - ky) ** 2 + (x - kx) ** 2) ** 0.5 < min_distance
            for ky, kx in kept_coords
        )
        if too_close:
            continue

        kept_coords.append((y, x))
        kept_values.append(value)

    if not kept_coords:
        return torch.empty((0, 2), device=device), torch.empty(0, device=device)

    return (
        torch.tensor(kept_coords, device=device, dtype=torch.float32),
        torch.tensor(kept_values, device=device, dtype=torch.float32),
    )