from typing import NamedTuple, Optional

import torch
from torch import nn


def _run_kernel(x: torch.Tensor, mean: torch.Tensor, tx: torch.Tensor):
    """Subtract `mean` from the features and apply the linear transform `tx`."""
    if x.ndim <= 3:
        # Channels-last layouts (e.g. (batch, tokens, channels)): plain matrix multiply.
        x = x - mean
        x = x @ tx.T
    elif x.ndim == 4:
        # Channels-first spatial layout (B, C, H, W): apply the same transform
        # as a 1x1 convolution so the spatial dimensions are preserved.
        x = x - mean.reshape(1, -1, 1, 1)
        kernel = tx.reshape(*tx.shape, 1, 1)
        x = torch.nn.functional.conv2d(x, weight=kernel, bias=None, stride=1, padding=0)
    else:
        raise ValueError(f'Unsupported input dimension: {x.ndim}, shape: {x.shape}')
    return x


class FeatureNormalizer(nn.Module):
    """Normalizes features by subtracting a stored mean and applying a stored linear transform."""

    def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32):
        super().__init__()

        # Defaults are a zero mean and an identity transform (i.e. a no-op until loaded/fit).
        self.register_buffer('mean', torch.zeros(embed_dim, dtype=dtype))
        self.register_buffer('tx', torch.eye(embed_dim, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = _run_kernel(x, self.mean, self.tx)
        return x


class InterFeatState(NamedTuple):
    y: torch.Tensor
    alpha: torch.Tensor


class IntermediateFeatureNormalizerBase(nn.Module):
    def forward(self, x: torch.Tensor, index: int, rot_index: Optional[int] = None, skip: Optional[int] = None) -> InterFeatState:
        raise NotImplementedError()


class IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
    def __init__(self, num_intermediates: int, embed_dim: int, rot_per_layer: bool = False, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.register_buffer('alphas', torch.ones(num_intermediates, dtype=dtype))

        # Either a single rotation shared across layers, or one rotation per intermediate layer.
        rot = torch.eye(embed_dim, dtype=dtype)
        if rot_per_layer:
            rot = rot.unsqueeze(0).repeat(num_intermediates, 1, 1)

        self.register_buffer('rotation', rot.contiguous())
        self.register_buffer('means', torch.zeros(num_intermediates, embed_dim, dtype=dtype))

    def forward(self, x: torch.Tensor, index: int, rot_index: Optional[int] = None, skip: Optional[int] = None) -> InterFeatState:
        if rot_index is None:
            rot_index = index

        if skip:
            assert x.ndim == 3, 'Cannot use the `skip` parameter when the `x` tensor isn\'t 3-dimensional.'
            # Hold out the first `skip` tokens (e.g. summary tokens) and only normalize the rest.
            prefix, x = x[:, :skip], x[:, skip:]

        rotation = self._get_rotation(rot_index)
        y = _run_kernel(x, self.means[index], rotation)

        alpha = self.alphas[index]
        if skip:
            # Skipped tokens keep a scale of 1; normalized tokens use the learned alpha.
            alpha = torch.cat([
                torch.ones(skip, dtype=alpha.dtype, device=alpha.device),
                alpha[None].expand(y.shape[1]),
            ]).reshape(1, -1, 1)
            y = torch.cat([prefix, y], dim=1)
        else:
            if x.ndim == 3:
                alpha = alpha.reshape(1, 1, 1).expand(1, y.shape[1], 1)
            elif x.ndim == 4:
                alpha = alpha.reshape(1, 1, 1, 1).expand(1, 1, *y.shape[2:])
            else:
                raise ValueError(f'Unsupported input dimension: {x.ndim}')

        return InterFeatState(y, alpha)

    def _get_rotation(self, rot_index: int) -> torch.Tensor:
        # A 2D buffer means a single rotation shared across all intermediate layers.
        if self.rotation.ndim == 2:
            return self.rotation
        return self.rotation[rot_index]


class NullIntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
    """No-op normalizer that returns the input unchanged with an alpha of 1."""

    instances = dict()

    def __init__(self, dtype: torch.dtype, device: torch.device):
        super().__init__()
        self.register_buffer('alpha', torch.tensor(1, dtype=dtype, device=device))

    @staticmethod
    def get_instance(dtype: torch.dtype, device: torch.device):
        # Cache one instance per (dtype, device) pair.
        instance = NullIntermediateFeatureNormalizer.instances.get((dtype, device), None)
        if instance is None:
            instance = NullIntermediateFeatureNormalizer(dtype, device)
            NullIntermediateFeatureNormalizer.instances[(dtype, device)] = instance
        return instance

    def forward(self, x: torch.Tensor, index: int, rot_index: Optional[int] = None, skip: Optional[int] = None) -> InterFeatState:
        return InterFeatState(x, self.alpha)
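

# Minimal usage sketch (illustrative, with default buffers): it assumes token inputs
# shaped (batch, tokens, embed_dim) and spatial inputs shaped (batch, channels, H, W);
# the shapes and parameter values below are arbitrary examples, not fixed by the module.
if __name__ == '__main__':
    embed_dim = 8
    tokens = torch.randn(2, 5, embed_dim)

    # With the default buffers (zero mean, identity transform), the output equals the input.
    feat_norm = FeatureNormalizer(embed_dim)
    assert torch.allclose(feat_norm(tokens), tokens)

    # Per-layer rotations, with the first token held out via `skip`.
    inter_norm = IntermediateFeatureNormalizer(num_intermediates=3, embed_dim=embed_dim, rot_per_layer=True)
    y, alpha = inter_norm(tokens, index=1, skip=1)
    assert y.shape == tokens.shape and alpha.shape == (1, tokens.shape[1], 1)

    # Spatial (B, C, H, W) inputs go through the 1x1-convolution path.
    spatial = torch.randn(2, embed_dim, 4, 4)
    y, alpha = inter_norm(spatial, index=0)
    assert y.shape == spatial.shape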