"""Image to Patch Embedding using Conv2d.

Modified from vision_transformer
(https://github.com/google-research/vision_transformer).
"""

from __future__ import annotations

from collections.abc import Callable

import torch
from torch import nn


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding."""

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 768,
        norm_layer: Callable[..., nn.Module] | None = None,
        flatten: bool = True,
        bias: bool = True,
    ):
        """Init PatchEmbed.

        Args:
            img_size (int, optional): Input image size (assumed square).
                Defaults to 224.
            patch_size (int, optional): Patch size (assumed square). Defaults
                to 16.
            in_channels (int, optional): Number of input image channels.
                Defaults to 3.
            embed_dim (int, optional): Patch embedding dimension. Defaults to
                768.
            norm_layer (Callable[..., nn.Module], optional): Normalization
                layer constructor applied to the embedding dimension (e.g.
                nn.LayerNorm). Defaults to None, which means no normalization
                layer.
            flatten (bool, optional): Whether to flatten the spatial
                dimensions of the output tensor. Defaults to True.
            bias (bool, optional): Whether to add a bias to the convolution
                layer. Defaults to True.

        Raises:
            ValueError: If the input image's size is not divisible by the patch
                size.
        """
        super().__init__()
        if img_size % patch_size != 0:
            raise ValueError(
                f"Input image size ({img_size}) must be divisible by the "
                f"patch size ({patch_size})."
            )
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        self.grid_size = (
            self.img_size[0] // self.patch_size[0],
            self.img_size[1] // self.patch_size[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        """Applies the layer.

        Args:
            data (torch.Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            torch.Tensor: Output tensor of shape (B, N, embed_dim), where N
                is the number of patches
                (N = (H / patch_size) * (W / patch_size)).
        """
        return self._call_impl(data)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function."""
        _, _, height, width = x.shape
        assert height == self.img_size[0], (
            f"Input image height ({height}) doesn't match model "
            f"({self.img_size[0]})."
        )
        assert width == self.img_size[1], (
            f"Input image width ({width}) doesn't match model "
            f"({self.img_size[1]})."
        )
        x = self.proj(x)  # (B, in_channels, H, W) -> (B, embed_dim, H/p, W/p)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # -> (B, N, embed_dim)
        x = self.norm(x)
        return x
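

# Minimal usage sketch (illustration only, not part of the original module):
# a shape check for the default ViT-Base style configuration. The variable
# names and the random input below are assumptions made for this example.
if __name__ == "__main__":
    patch_embed = PatchEmbed(
        img_size=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        norm_layer=nn.LayerNorm,  # optional; omit for no normalization
    )
    dummy = torch.randn(2, 3, 224, 224)  # (B, C, H, W)
    tokens = patch_embed(dummy)
    # 224 / 16 = 14 patches per side -> 14 * 14 = 196 tokens of dim 768.
    assert tokens.shape == (2, patch_embed.num_patches, 768)
    print(tokens.shape)  # torch.Size([2, 196, 768])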