from typing import Dict

import timm.models.vision_transformer as vit
import torch


def build_imagenet_baselines() -> Dict[str, torch.jit.ScriptModule]:
    """Return the prepped ImageNet ViT encoders from timm; they make reasonable baselines for microscopy data."""
    vit_backbones = [
        _make_vit(vit.vit_small_patch16_384),
        _make_vit(vit.vit_base_patch16_384),
        _make_vit(vit.vit_base_patch8_224),
        _make_vit(vit.vit_large_patch16_384),
    ]
    model_names = [
        "vit_small_patch16_384",
        "vit_base_patch16_384",
        "vit_base_patch8_224",
        "vit_large_patch16_384",
    ]
    imagenet_encoders = list(map(_make_torchscripted_encoder, vit_backbones))
    return {name: model for name, model in zip(model_names, imagenet_encoders)}


def _make_torchscripted_encoder(vit_backbone) -> torch.jit.ScriptModule:
    # Dummy batch of 6-channel uint8 images; the forward pass below materializes
    # the LazyInstanceNorm2d layer before the module is scripted.
    dummy_input = torch.testing.make_tensor(
        (2, 6, 256, 256),
        low=0,
        high=255,
        dtype=torch.uint8,
        device=torch.device("cpu"),
    )
    encoder = torch.nn.Sequential(
        Normalizer(),  # scale raw uint8 pixels into [0, 1]
        torch.nn.LazyInstanceNorm2d(affine=False, track_running_stats=False),
        vit_backbone,
    ).to(device="cpu")
    _ = encoder(dummy_input)
    return torch.jit.freeze(torch.jit.script(encoder.eval()))


def _make_vit(constructor):
    # ImageNet-pretrained ViT resized to 256x256 inputs, with 6 input channels
    # and the classification head removed (num_classes=0, average pooling).
    return constructor(
        pretrained=True,
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        global_pool="avg",
    )


class Normalizer(torch.nn.Module):
    """Cast uint8 pixels to float and scale them into [0, 1]."""

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        pixels = pixels.float()
        pixels /= 255.0
        return pixels
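

# A minimal usage sketch (not part of the original module): build the TorchScripted
# baselines and embed a dummy 6-channel uint8 batch with each of them. The embedding
# width depends on the backbone (e.g. 384 for vit_small, 768 for vit_base).
if __name__ == "__main__":
    encoders = build_imagenet_baselines()
    batch = torch.randint(0, 256, (2, 6, 256, 256), dtype=torch.uint8)
    for name, encoder in encoders.items():
        embeddings = encoder(batch)
        print(name, tuple(embeddings.shape))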