|  | from typing import Any, Dict, List, Optional, Union | 
					
						
						|  |  | 
					
						
						|  | import numpy as np | 
					
						
						|  | import PIL.Image | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  | from ..utils.import_utils import is_kornia_available | 
					
						
						|  | from .base import ProcessorMixin | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if is_kornia_available(): | 
					
						
						|  | import kornia | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class CannyProcessor(ProcessorMixin): | 
					
						
						|  | r""" | 
					
						
						|  | Processor for obtaining the Canny edge detection of an image. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | output_names (`List[str]`): | 
					
						
						|  | The names of the outputs that the processor should return. The first output is the Canny edge detection of | 
					
						
						|  | the input image. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | output_names: List[str] = None, | 
					
						
						|  | input_names: Optional[Dict[str, Any]] = None, | 
					
						
						|  | device: Optional[torch.device] = None, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | self.output_names = output_names | 
					
						
						|  | self.input_names = input_names | 
					
						
						|  | self.device = device | 
					
						
						|  | assert len(output_names) == 1 | 
					
						
						|  |  | 
					
						
						|  | def forward(self, input: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]]) -> torch.Tensor: | 
					
						
						|  | r""" | 
					
						
						|  | Obtain the Canny edge detection of the input image. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | input (`torch.Tensor`, `PIL.Image.Image`, or `List[PIL.Image.Image]`): | 
					
						
						|  | The input tensor, image or list of images for which the Canny edge detection should be obtained. | 
					
						
						|  | If a tensor, must be a 3D (CHW) or 4D (BCHW) or 5D (BTCHW) tensor. The input tensor should have | 
					
						
						|  | values in the range [0, 1]. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | torch.Tensor: | 
					
						
						|  | The Canny edge detection of the input image. The output has the same shape as the input tensor. If | 
					
						
						|  | the input is an image, the output is a 3D tensor. If the input is a list of images, the output is a 5D | 
					
						
						|  | tensor. The output tensor has values in the range [0, 1]. | 
					
						
						|  | """ | 
					
						
						|  | if isinstance(input, PIL.Image.Image): | 
					
						
						|  | input = kornia.utils.image.image_to_tensor(np.array(input)).unsqueeze(0) / 255.0 | 
					
						
						|  | input = input.to(self.device) | 
					
						
						|  | output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1).squeeze(0) | 
					
						
						|  | elif isinstance(input, list): | 
					
						
						|  | input = kornia.utils.image.image_list_to_tensor([np.array(img) for img in input]) / 255.0 | 
					
						
						|  | output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1) | 
					
						
						|  | else: | 
					
						
						|  | ndim = input.ndim | 
					
						
						|  | assert ndim in [3, 4, 5] | 
					
						
						|  |  | 
					
						
						|  | batch_size = 1 if ndim == 3 else input.size(0) | 
					
						
						|  |  | 
					
						
						|  | if ndim == 3: | 
					
						
						|  | input = input.unsqueeze(0) | 
					
						
						|  | elif ndim == 5: | 
					
						
						|  | input = input.flatten(0, 1) | 
					
						
						|  |  | 
					
						
						|  | output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1) | 
					
						
						|  | output = output[0] if ndim == 3 else output.unflatten(0, (batch_size, -1)) if ndim == 5 else output | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | return {self.output_names[0]: output} | 
					
						
						|  |  |