# AdaFortiTran / src/models/blocks/patch_processors.py
# Author: BerkIGuler
# Commit: "fixes on src/models" (687eaba)
from torch import nn
from typing import Tuple
import torch
class PatchEmbedding(nn.Module):
    """Split a channel matrix into a sequence of flattened patches.

    The matrix is tiled with non-overlapping ``patch_size`` windows (the
    stride equals the window size). Each window is flattened in row-major
    order, and the flattened windows are emitted in row-major tile order.
    Trailing rows/columns that cannot fill a complete window are discarded.
    """

    def __init__(self, patch_size: Tuple[int, int] = (10, 4)):
        """Build the patch extractor.

        Args:
            patch_size: Window size as (subcarriers_per_patch, symbols_per_patch).
        """
        super().__init__()
        self.patch_size = patch_size
        # kernel_size == stride -> tiles touch but never overlap.
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten every patch of ``x`` into one sequence element.

        Args:
            x: Tensor of shape (batch_size, num_subcarriers, num_symbols).

        Returns:
            Tensor of shape (batch_size, num_patches, patch_size[0] * patch_size[1]),
            with num_patches = (num_subcarriers // patch_size[0])
            * (num_symbols // patch_size[1]).
        """
        # nn.Unfold expects (N, C, H, W); add a singleton channel axis.
        with_channel = x.unsqueeze(1)
        # Unfold yields (batch, patch_elems, num_patches) ...
        columns = self.unfold(with_channel)
        # ... so swap the last two axes to get one row per patch.
        return columns.transpose(1, 2)
class InversePatchEmbedding(nn.Module):
    """Transform patch embeddings back to original matrix format.

    Each input vector is reshaped into a ``patch_size`` window (row-major)
    and written to its tile position in an ``output_size`` matrix; tiles are
    non-overlapping and visited in row-major order.
    """

    def __init__(
        self,
        output_size: Tuple[int, int] = (120, 14),
        patch_size: Tuple[int, int] = (3, 2)
    ):
        """Initialize the InversePatchEmbedding layer.

        Args:
            output_size: Size of output matrix (num_subcarriers, num_symbols)
            patch_size: Size of input patches (subcarriers_per_patch, symbols_per_patch)
        """
        super().__init__()
        # Keep the configuration inspectable (PatchEmbedding likewise exposes
        # ``patch_size``); stored exactly as passed so int arguments still work.
        self.output_size = output_size
        self.patch_size = patch_size
        # kernel_size == stride -> each patch fills its own tile, no overlap.
        self.fold = nn.Fold(
            output_size=output_size,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform patch embeddings back to matrix format.

        Args:
            x: Input tensor of shape (batch_size, num_patches, patch_size[0]*patch_size[1])
                where num_patches = (output_size[0] // patch_size[0]) * (output_size[1] // patch_size[1])

        Returns:
            Tensor of shape (batch_size, num_subcarriers, num_symbols). If
            ``output_size`` is not a multiple of ``patch_size``, the trailing
            rows/columns not covered by any patch are zero.
        """
        # nn.Fold expects (batch, patch_elems, num_patches).
        x = torch.permute(x, dims=(0, 2, 1))
        x = self.fold(x)
        # Drop the singleton channel axis produced by Fold.
        return torch.squeeze(x, dim=1)