| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from argparse import Namespace |
| | from typing import NamedTuple, Optional |
| |
|
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| |
|
| |
|
class AdaptorInput(NamedTuple):
    """Immutable bundle of everything an adaptor's ``forward`` receives.

    Fields are positional in the order declared below; build with
    ``AdaptorInput(images, summary, features, feature_fmt, patch_size)``
    or by keyword.
    """

    # Input image tensor (layout not established here — TODO confirm).
    images: torch.Tensor
    # Summary tensor (semantics defined by the producer — verify against caller).
    summary: torch.Tensor
    # Feature tensor; its layout is presumably described by `feature_fmt`.
    features: torch.Tensor
    # String tag naming the layout of `features` — presumably; confirm with callers.
    feature_fmt: str
    # Integer patch size (presumably pixels per patch side — TODO confirm).
    patch_size: int
| |
|
| |
|
class RadioOutput(NamedTuple):
    """Model output pairing a summary tensor with a features tensor.

    Either field may be ``None``; ``to`` below passes ``None`` through
    unchanged, so the annotations say ``Optional``.
    """

    summary: Optional[torch.Tensor]
    features: Optional[torch.Tensor]

    def to(self, *args, **kwargs) -> "RadioOutput":
        """Return a new ``RadioOutput`` with each tensor field converted via ``Tensor.to``.

        All positional and keyword arguments are forwarded verbatim to
        ``torch.Tensor.to`` (device, dtype, ``non_blocking``, ...). Fields
        that are ``None`` stay ``None``. The original object is not modified.

        Returns:
            A new ``RadioOutput`` holding the converted tensors.
        """
        # Iterate the tuple's own fields so new fields are handled automatically.
        return RadioOutput(
            *(t.to(*args, **kwargs) if t is not None else None for t in self)
        )
| |
|
| |
|
class AdaptorBase(nn.Module):
    """Abstract adaptor: maps an ``AdaptorInput`` to a ``RadioOutput``.

    Concrete adaptors subclass this and override ``forward``; the base
    implementation only signals that the override is missing.
    """

    def forward(self, input: AdaptorInput) -> RadioOutput:
        """Transform ``input`` into a ``RadioOutput``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        missing_override = "Subclasses must implement this!"
        raise NotImplementedError(missing_override)
| |
|