# © Recursion Pharmaceuticals 2024
from typing import Callable, Dict

import timm.models.vision_transformer as vit
import torch


def build_imagenet_baselines() -> Dict[str, torch.jit.ScriptModule]:
    """Return the prepped ImageNet encoders from timm, keyed by model name.

    Each entry is a timm ViT built by ``_make_vit`` and wrapped by
    ``_make_torchscripted_encoder`` (pixel scaling, instance normalization,
    then the backbone), scripted and frozen with TorchScript.
    These ImageNet baselines are not bad for microscopy data.

    Returns:
        Mapping from the timm model name to its frozen TorchScript encoder.
    """
    # Each name sits next to its constructor so the two cannot drift apart.
    named_constructors = (
        ("vit_small_patch16_384", vit.vit_small_patch16_384),
        ("vit_base_patch16_384", vit.vit_base_patch16_384),
        ("vit_base_patch8_224", vit.vit_base_patch8_224),
        ("vit_large_patch16_384", vit.vit_large_patch16_384),
    )
    return {
        name: _make_torchscripted_encoder(_make_vit(constructor))
        for name, constructor in named_constructors
    }


def _make_torchscripted_encoder(
    vit_backbone: torch.nn.Module,
) -> torch.jit.ScriptModule:
    """Wrap a ViT backbone with input preprocessing and freeze it as TorchScript.

    The encoder is ``Normalizer`` -> ``LazyInstanceNorm2d`` -> ``vit_backbone``,
    kept on the CPU.

    Args:
        vit_backbone: The backbone module, e.g. as produced by ``_make_vit``.

    Returns:
        The scripted and frozen encoder, in eval mode.
    """
    dummy_input = torch.testing.make_tensor(
        (2, 6, 256, 256),
        low=0,
        high=255,
        dtype=torch.uint8,
        device=torch.device("cpu"),
    )
    encoder = torch.nn.Sequential(
        Normalizer(),
        torch.nn.LazyInstanceNorm2d(
            affine=False, track_running_stats=False
        ),  # this module performs self-standardization, very important
        vit_backbone,
    ).to(device="cpu")
    # One throwaway forward pass gets the lazy modules built before scripting.
    # Gradients are never needed here, so skip recording the autograd graph.
    with torch.no_grad():
        encoder(dummy_input)
    return torch.jit.freeze(torch.jit.script(encoder.eval()))


def _make_vit(constructor: Callable[..., torch.nn.Module]) -> torch.nn.Module:
    """Build a pretrained timm ViT configured for 6-channel 256x256 crops.

    Args:
        constructor: A timm vision-transformer factory such as
            ``vit.vit_base_patch16_384``.

    Returns:
        The model returned by ``constructor`` for the settings below.
    """
    return constructor(
        pretrained=True,  # download imagenet weights
        img_size=256,  # 256x256 crops
        in_chans=6,  # we expect 6-channel microscopy images
        num_classes=0,
        fc_norm=None,
        class_token=True,
        global_pool="avg",  # minimal perf diff btwn "cls" and "avg"
    )


class Normalizer(torch.nn.Module):
    """Convert pixels to float32 and divide them by 255."""

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        """Return ``pixels`` as float32 divided by 255, leaving the input intact.

        Args:
            pixels: Image tensor; the encoder in this file feeds it uint8 data.

        Returns:
            A new float32 tensor holding ``pixels / 255``.
        """
        # Out-of-place division on purpose: ``.float()`` hands back the very
        # same tensor when the input is already float32, so an in-place ``/=``
        # would silently rescale the caller's data.
        return pixels.float() / 255.0