File size: 1,969 Bytes
6ded986
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# © Recursion Pharmaceuticals 2024
from typing import Dict

import timm.models.vision_transformer as vit
import torch


def build_imagenet_baselines() -> Dict[str, torch.jit.ScriptModule]:
    """Return the prepped ImageNet encoders from timm (not bad for microscopy data).

    Each backbone is created with ``_make_vit`` from the timm constructor of
    the same name, then wrapped and frozen with ``_make_torchscripted_encoder``.

    Returns:
        A dict mapping each timm constructor name (e.g.
        ``"vit_small_patch16_384"``) to its frozen TorchScript encoder,
        in the order listed below.
    """
    # One list drives both the constructor lookup and the dict keys, so a
    # name can never be paired with the wrong backbone.
    model_names = (
        "vit_small_patch16_384",
        "vit_base_patch16_384",
        "vit_base_patch8_224",
        "vit_large_patch16_384",
    )
    encoders: Dict[str, torch.jit.ScriptModule] = {}
    for name in model_names:
        # Build and script each backbone in turn rather than holding every
        # eager model in a list at once (may lower peak memory).
        backbone = _make_vit(getattr(vit, name))
        encoders[name] = _make_torchscripted_encoder(backbone)
    return encoders


def _make_torchscripted_encoder(vit_backbone) -> torch.jit.ScriptModule:
    """Wrap a ViT backbone with preprocessing and return a frozen TorchScript module.

    The wrapped pipeline is ``Normalizer`` (cast to float, divide by 255),
    then ``LazyInstanceNorm2d`` without affine parameters or running stats
    (self-standardization of each input), then ``vit_backbone``.

    Args:
        vit_backbone: module applied to the standardized batch; assumed to
            accept 6-channel 256x256 inputs as produced by ``_make_vit``
            (TODO confirm for other callers).

    Returns:
        The encoder in eval mode, scripted and frozen.
    """
    # Shape must match what the backbone expects: (batch, 6 channels, 256, 256).
    dummy_input = torch.testing.make_tensor(
        (2, 6, 256, 256),
        low=0,
        high=255,
        dtype=torch.uint8,
        device=torch.device("cpu"),
    )
    encoder = torch.nn.Sequential(
        Normalizer(),
        torch.nn.LazyInstanceNorm2d(
            affine=False, track_running_stats=False
        ),  # this module performs self-standardization, very important
        vit_backbone,
    ).to(device="cpu")
    encoder.eval()
    # The dummy forward only exists to materialize the lazy module; no
    # gradients are needed, so avoid building an autograd graph through the ViT.
    with torch.no_grad():
        _ = encoder(dummy_input)
    return torch.jit.freeze(torch.jit.script(encoder))


def _make_vit(constructor):
    return constructor(
        pretrained=True,  # download imagenet weights
        img_size=256,  # 256x256 crops
        in_chans=6,  # we expect 6-channel microscopy images
        num_classes=0,
        fc_norm=None,
        class_token=True,
        global_pool="avg",  # minimal perf diff btwn "cls" and "avg"
    )


class Normalizer(torch.nn.Module):
    """Cast pixels to float32 and scale them by 1/255 (0..255 -> 0..1 for uint8 input)."""

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        """Return ``pixels`` as float32 divided by 255.

        The input tensor is never modified. ``Tensor.float()`` returns the
        same tensor when it is already float32, so an in-place division
        would otherwise overwrite the caller's data.
        """
        return pixels.float() / 255.0