File size: 3,106 Bytes
9fd1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from typing import Any, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch

from ..utils.import_utils import is_kornia_available
from .base import ProcessorMixin


# kornia is imported only when `is_kornia_available()` reports it as available, so importing this module does
# not fail without it. `CannyProcessor.forward` references `kornia` directly.
if is_kornia_available():
    import kornia


class CannyProcessor(ProcessorMixin):
    r"""
    Processor for obtaining the Canny edge detection of an image.

    Edges are computed with `kornia.filters.canny` (called without extra arguments) and the resulting edge map is
    repeated three times along the channel dimension.

    Args:
        output_names (`List[str]`):
            The names of the outputs that the processor should return. Must contain exactly one name; it is used as
            the key under which the Canny edge detection of the input is returned.
        input_names (`Dict[str, Any]`, *optional*):
            Stored on the processor as `input_names`. It is not read by `forward`.
        device (`torch.device`, *optional*):
            Device that PIL image inputs (single image or list) are moved to before edge detection. Tensor inputs
            are processed on whatever device they already live on.

    Raises:
        ValueError: If `output_names` is `None` or does not contain exactly one name.
    """

    def __init__(
        self,
        output_names: Optional[List[str]] = None,
        input_names: Optional[Dict[str, Any]] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()

        # Validate explicitly instead of with `assert`: asserts are stripped under `python -O`, and the default
        # `None` would otherwise surface as an unrelated `TypeError` from `len(None)`.
        if output_names is None or len(output_names) != 1:
            raise ValueError(
                f"`output_names` must contain exactly one name for CannyProcessor, but got {output_names!r}."
            )

        self.output_names = output_names
        self.input_names = input_names
        self.device = device

    @staticmethod
    def _canny_three_channel(images: torch.Tensor) -> torch.Tensor:
        # Run kornia's Canny on a batched tensor and repeat the edge map 3x along dim 1.
        # `kornia.filters.canny` returns a pair; the second element is what this processor uses as the edge map.
        return kornia.filters.canny(images)[1].repeat(1, 3, 1, 1)

    def forward(
        self, input: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor]:
        r"""
        Obtain the Canny edge detection of the input image.

        Args:
            input (`torch.Tensor`, `PIL.Image.Image`, or `List[PIL.Image.Image]`):
                The input tensor, image or list of images for which the Canny edge detection should be obtained.
                If a tensor, must be a 3D (CHW) or 4D (BCHW) or 5D (BTCHW) tensor. The input tensor should have
                values in the range [0, 1]. PIL images are converted to tensors and divided by 255 before
                processing.

        Returns:
            `Dict[str, torch.Tensor]`:
                A dictionary with a single entry, keyed by `output_names[0]`, holding the Canny edge detection of
                the input. For tensor inputs, the output has the same number of dimensions as the input. If the
                input is a single image, the output is a 3D tensor. If the input is a list of images, the output
                is a 4D tensor with one entry per image along the first dimension. In all cases the edge map is
                repeated three times along the channel dimension.

        Raises:
            ValueError: If `input` is a tensor whose number of dimensions is not 3, 4 or 5.
        """
        if isinstance(input, PIL.Image.Image):
            image = kornia.utils.image.image_to_tensor(np.array(input)).unsqueeze(0) / 255.0
            # Drop the temporary batch dimension again so a single image yields a 3D result.
            output = self._canny_three_channel(image.to(self.device)).squeeze(0)
        elif isinstance(input, list):
            images = kornia.utils.image.image_list_to_tensor([np.array(img) for img in input]) / 255.0
            # Move to the configured device, exactly like the single-image branch does.
            output = self._canny_three_channel(images.to(self.device))
        else:
            ndim = input.ndim
            if ndim not in (3, 4, 5):
                raise ValueError(f"Expected a 3D, 4D or 5D tensor as `input`, but got a tensor with {ndim} dimensions.")

            if ndim == 3:
                # [C, H, W] -> [1, C, H, W] for kornia, then back to 3D.
                output = self._canny_three_channel(input.unsqueeze(0))[0]
            elif ndim == 5:
                batch_size = input.size(0)
                # [B, F, C, H, W] -> [B*F, C, H, W] for kornia, then restore the leading batch dimension.
                output = self._canny_three_channel(input.flatten(0, 1)).unflatten(0, (batch_size, -1))
            else:
                output = self._canny_three_channel(input)

        # TODO(aryan): think about how one can pass parameters to the underlying function from
        # a UI perspective. It's important to think about ProcessorMixin in terms of a Graph-based
        # data processing pipeline.
        return {self.output_names[0]: output}