from pathlib import Path
from typing import List

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class DeCurWrapper(nn.Module):
    """Wraps a timm ViT-S/16 encoder carrying DeCUR pretrained weights for a
    single modality: Sentinel-1 SAR or Sentinel-2 multispectral optical."""

    def __init__(
        self, weights_path: Path, modality: str, do_pool: bool = True, temporal_pooling: str = "mean"
    ):
        super().__init__()
        assert modality in ["SAR", "optical"], f"Unexpected modality: {modality}"

        self.encoder = timm.create_model("vit_small_patch16_224", pretrained=False)
        self.dim = 384
        self.modality = modality
        if modality == "optical":
            self.encoder.patch_embed.proj = torch.nn.Conv2d(
                13, 384, kernel_size=(16, 16), stride=(16, 16)
            )
            state_dict = torch.load(weights_path / "vits16_ssl4eo-s12_ms_decur_ep100.pth")
            msg = self.encoder.load_state_dict(state_dict, strict=False)
            assert set(msg.missing_keys) == {"head.weight", "head.bias"}
        else:
            self.encoder.patch_embed.proj = torch.nn.Conv2d(
                2, 384, kernel_size=(16, 16), stride=(16, 16)
            )
            state_dict = torch.load(weights_path / "vits16_ssl4eo-s12_sar_decur_ep100.pth")
            msg = self.encoder.load_state_dict(state_dict, strict=False)
            assert set(msg.missing_keys) == {"head.weight", "head.bias"}

        self.image_resolution = 224
        self.patch_size = 16
        self.grid_size = self.image_resolution // self.patch_size
        self.do_pool = do_pool
        if temporal_pooling not in ["mean", "max"]:
            raise ValueError(
                f"Expected temporal_pooling to be in ['mean', 'max'], got {temporal_pooling}"
            )
        self.temporal_pooling = temporal_pooling

    def resize(self, images):
        # Bilinearly resize to the encoder's native 224x224 input resolution.
        images = F.interpolate(
            images,
            size=(self.image_resolution, self.image_resolution),
            mode="bilinear",
            align_corners=False,
        )
        return images

    def preprocess(self, images):
        # Channels-last (b, h, w, c) in; channels-first, resized (b, c, h, w) out.
        images = rearrange(images, "b h w c -> b c h w")
        assert images.shape[1] in (13, 2), f"Unexpected channel count: {images.shape[1]}"
        return self.resize(images)  # (bsz, C, H, W)

    def forward(self, s2=None, s1=None, months=None):
        # months is accepted for interface compatibility but is not used.
        if s1 is not None:
            assert self.modality == "SAR"
            x = s1
        elif s2 is not None:
            assert self.modality == "optical"
            x = s2
        else:
            raise ValueError("Expected exactly one of s1 or s2 to be passed")

        if len(x.shape) == 5:
            # Time series of shape (b, h, w, t, c): encode each timestep
            # independently, then pool over the temporal axis.
            outputs_l: List[torch.Tensor] = []
            for timestep in range(x.shape[3]):
                image = self.preprocess(x[:, :, :, timestep])
                output = self.encoder.forward_features(image)
                if self.do_pool:
                    output = output.mean(dim=1)  # average over all tokens -> (b, d)
                else:
                    output = output[:, 1:]  # drop the CLS token -> (b, n, d)
                outputs_l.append(output)
            outputs_t = torch.stack(outputs_l, dim=-1)  # (b, d, t) or (b, n, d, t)
            if self.temporal_pooling == "mean":
                return outputs_t.mean(dim=-1)
            return torch.amax(outputs_t, dim=-1)

        # Single timestep of shape (b, h, w, c).
        x = self.preprocess(x)
        output = self.encoder.forward_features(x)
        if self.do_pool:
            return output.mean(dim=1)
        return output[:, 1:]
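

# Minimal usage sketch (assumptions: the DeCUR checkpoints have already been
# downloaded to "./weights", and a random single-timestep SAR batch stands in
# for real Sentinel-1 data).
if __name__ == "__main__":
    model = DeCurWrapper(Path("./weights"), modality="SAR")
    model.eval()
    # Channels-last (batch, height, width, channels) input, as preprocess expects.
    s1 = torch.randn(2, 224, 224, 2)
    with torch.no_grad():
        features = model(s1=s1)
    print(features.shape)  # torch.Size([2, 384]) with do_pool=True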