import math
from typing import Dict, Optional

import torch
from torch import nn

from einops import rearrange
from timm.models.vision_transformer import Block

from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam
from .adaptor_base import AdaptorModuleBase
from .adaptor_mlp import MLP2


class AttnFDHead(AdaptorModuleBase):
    """Adaptor head that refines spatial tokens with two transformer blocks
    before projecting them to the target dimension with an MLP."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_inner: int = 0,
        pre_norm: bool = False,
        device: Optional[torch.device] = None,
        upsample_factor: int = 1,
        upsample_rank: int = 0,
        **kwargs,
    ) -> None:
        super().__init__(requires_summary_and_spatial=False)
        # Two pre-norm ViT blocks with LayerScale (init_values) refine the tokens.
        self.blocks = nn.Sequential(*[
            Block(input_size, num_heads=16, init_values=1e-5)
            for _ in range(2)
        ])
        # The MLP is built with num_inner=0; depth comes from the attention blocks above.
        self.mlp = MLP2(input_size, hidden_size, output_size,
                        num_inner=0, pre_norm=pre_norm, device=device,
                        upsample_factor=upsample_factor, upsample_rank=upsample_rank, **kwargs)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # x: (batch, num_tokens, input_size) spatial tokens.
        x = self.blocks(x)
        x = self.mlp(x)
        return x
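

# Hypothetical smoke test (illustration only, not part of the original module).
# Shapes and sizes below are assumptions: input_size must be divisible by the
# blocks' 16 heads, and MLP2 is assumed to accept the arguments exactly as used
# in AttnFDHead.__init__ above.
if __name__ == "__main__":
    head = AttnFDHead(input_size=1024, hidden_size=2048, output_size=1536)
    tokens = torch.randn(2, 196, 1024)  # (batch, num_patches, input_size)
    out = head(tokens)
    print(out.shape)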