| from torch import nn | |
| from typing import Tuple | |
| import torch | |
class PatchEmbedding(nn.Module):
    """Convert a 2D channel matrix into a sequence of flattened patches.

    Non-overlapping ``patch_size`` tiles are cut from the matrix and each
    tile is flattened, producing one vector per patch in row-major tile
    order.
    """

    def __init__(self, patch_size: Tuple[int, int] = (10, 4)):
        """Initialize the PatchEmbedding layer.

        Args:
            patch_size: Tile shape to extract, given as
                (subcarriers_per_patch, symbols_per_patch).
        """
        super().__init__()
        self.patch_size = patch_size
        # Stride equals the kernel size, so extracted patches do not overlap.
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Extract and flatten patches from the input matrix.

        Args:
            x: Input tensor of shape (batch_size, num_subcarriers, num_symbols).

        Returns:
            Tensor of shape (batch_size, num_patches, patch_size[0] * patch_size[1])
            where num_patches = (num_subcarriers // patch_size[0])
            * (num_symbols // patch_size[1]).
        """
        # nn.Unfold expects (N, C, H, W); add a singleton channel axis.
        patches = self.unfold(x.unsqueeze(1))
        # (batch, patch_dim, num_patches) -> (batch, num_patches, patch_dim)
        return patches.transpose(1, 2)
class InversePatchEmbedding(nn.Module):
    """Reassemble a sequence of flattened patches into the original matrix."""

    def __init__(
        self,
        output_size: Tuple[int, int] = (120, 14),
        patch_size: Tuple[int, int] = (3, 2),
    ):
        """Initialize the InversePatchEmbedding layer.

        Args:
            output_size: Shape of the reconstructed matrix, given as
                (num_subcarriers, num_symbols).
            patch_size: Tile shape of the incoming patches, given as
                (subcarriers_per_patch, symbols_per_patch).
        """
        super().__init__()
        # Stride equals the kernel size, so patches tile the output
        # exactly once with no overlap.
        self.fold = nn.Fold(
            output_size=output_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fold a patch sequence back into matrix form.

        Args:
            x: Input tensor of shape
                (batch_size, num_patches, patch_size[0] * patch_size[1])
                where num_patches = (output_size[0] // patch_size[0])
                * (output_size[1] // patch_size[1]).

        Returns:
            Tensor of shape (batch_size, num_subcarriers, num_symbols).
        """
        # (batch, num_patches, patch_dim) -> (batch, patch_dim, num_patches)
        folded = self.fold(x.transpose(1, 2))
        # Drop the singleton channel axis produced by nn.Fold.
        return folded.squeeze(1)