from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

from loss import FourierLoss
from normalizer import Normalizer
from mae_modules import CAMAEDecoder, MAEDecoder, MAEEncoder
from mae_utils import flatten_images
from vit import (
    generate_2d_sincos_pos_embeddings,
    sincos_positional_encoding_vit,
    vit_small_patch16_256,
)

TensorDict = Dict[str, torch.Tensor]
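

# Configuration for the channel-agnostic MAE. The model below hard-codes its own
# encoder, decoder, and losses, so fields such as `encoder`, `decoder`, `loss`,
# `optimizer`, `input_norm`, and `lr_scheduler` are stored for config/checkpoint
# compatibility rather than consumed directly by MAEModel.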
class MAEConfig(PretrainedConfig):
    model_type = "MAE"

    def __init__(
        self,
        mask_ratio=0.75,
        encoder=None,
        decoder=None,
        loss=None,
        optimizer=None,
        input_norm=None,
        fourier_loss=None,
        fourier_loss_weight=0.0,
        lr_scheduler=None,
        use_MAE_weight_init=False,
        crop_size=-1,
        mask_fourier_loss=True,
        return_channelwise_embeddings=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.mask_ratio = mask_ratio
        self.encoder = encoder
        self.decoder = decoder
        self.loss = loss
        self.optimizer = optimizer
        self.input_norm = input_norm
        self.fourier_loss = fourier_loss
        self.fourier_loss_weight = fourier_loss_weight
        self.lr_scheduler = lr_scheduler
        self.use_MAE_weight_init = use_MAE_weight_init
        self.crop_size = crop_size
        self.mask_fourier_loss = mask_fourier_loss
        self.return_channelwise_embeddings = return_channelwise_embeddings
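

# Channel-agnostic masked autoencoder (CA-MAE): a ViT-S/16 encoder that tokenizes
# each image channel separately, plus a per-modality decoder that reconstructs the
# masked patches.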
class MAEModel(PreTrainedModel):
    config_class = MAEConfig

    TOTAL_LOSS = "loss"
    RECON_LOSS = "reconstruction_loss"
    FOURIER_LOSS = "fourier_loss"

    def __init__(self, config: MAEConfig):
        super().__init__(config)

        self.mask_ratio = config.mask_ratio
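
        # Encoder: ViT-S/16 with fixed 2D sin-cos positional embeddings, accepting up
        # to 11 input channels and tokenizing channels independently (channel-agnostic).
        # Decoder: CA-MAE decoder with 6 modalities of 256 tokens each
        # (a 256x256 image at patch size 16 yields 16x16 = 256 tokens per channel).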
        self.encoder = MAEEncoder(
            vit_backbone=sincos_positional_encoding_vit(
                vit_backbone=vit_small_patch16_256(global_pool="avg")
            ),
            max_in_chans=11,
            channel_agnostic=True,
        )
        self.decoder = CAMAEDecoder(
            depth=8,
            embed_dim=512,
            mlp_ratio=4,
            norm_layer=nn.LayerNorm,
            num_heads=16,
            num_modalities=6,
            qkv_bias=True,
            tokens_per_modality=256,
        )
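
        # Input normalization: intensity scaling (Normalizer) followed by per-channel
        # instance normalization with no affine parameters or running statistics.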
        self.input_norm = torch.nn.Sequential(
            Normalizer(),
            nn.InstanceNorm2d(None, affine=False, track_running_stats=False),
        )

        self.fourier_loss_weight = config.fourier_loss_weight
        self.mask_fourier_loss = config.mask_fourier_loss
        self.return_channelwise_embeddings = config.return_channelwise_embeddings
        self.tokens_per_channel = 256
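
        # Reconstruction loss is per-element MSE (reduction="none") so it can be
        # averaged per patch and restricted to masked patches; an optional Fourier-
        # domain loss can be mixed in with weight fourier_loss_weight in (0, 1).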
        self.loss = torch.nn.MSELoss(reduction="none")

        self.fourier_loss = FourierLoss(num_multimodal_modalities=6)
        if self.fourier_loss_weight > 0 and self.fourier_loss is None:
            raise ValueError(
                "fourier_loss_weight is set but no fourier_loss was defined in the constructor"
            )
        elif self.fourier_loss_weight >= 1:
            raise ValueError(
                "fourier_loss_weight is too large to be a mixing factor; it should be < 1"
            )

        self.patch_size = int(self.encoder.vit_backbone.patch_embed.patch_size[0])
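
        # Project encoder tokens into the decoder width, and map each decoder token
        # back to a flattened patch (patch_size**2 values; one channel per token in
        # the channel-agnostic setting).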
        self.encoder_decoder_proj = nn.Linear(
            self.encoder.embed_dim, self.decoder.embed_dim, bias=True
        )
        self.decoder_pred = nn.Linear(
            self.decoder.embed_dim,
            self.patch_size**2
            * (1 if self.encoder.channel_agnostic else self.in_chans),
            bias=True,
        )
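
        # The decoder uses fixed 2D sin-cos positional embeddings, repeated per
        # modality when the encoder is channel-agnostic.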
        self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings(
            self.decoder.embed_dim,
            length=self.encoder.vit_backbone.patch_embed.grid_size[0],
            use_class_token=self.encoder.vit_backbone.cls_token is not None,
            num_modality=(
                self.decoder.num_modalities if self.encoder.channel_agnostic else 1
            ),
        )
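
        # MAE-style weight initialization: Xavier-uniform on the patch embedding
        # (viewed as a linear layer), normal(std=0.02) for the class and mask tokens,
        # and Xavier/constant init for the remaining Linear/LayerNorm modules.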
        if config.use_MAE_weight_init:
            w = self.encoder.vit_backbone.patch_embed.proj.weight.data
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

            torch.nn.init.normal_(self.encoder.vit_backbone.cls_token, std=0.02)
            torch.nn.init.normal_(self.decoder.mask_token, std=0.02)

            self.apply(self._MAE_init_weights)

    def setup(self, stage: str) -> None:
        # Training-framework hook kept for interface compatibility; PreTrainedModel
        # defines no setup(), so there is nothing to delegate to here.
        pass

    def _MAE_init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @staticmethod
    def decode_to_reconstruction(
        encoder_latent: torch.Tensor,
        ind_restore: torch.Tensor,
        proj: torch.nn.Module,
        decoder: MAEDecoder | CAMAEDecoder,
        pred: torch.nn.Module,
    ) -> torch.Tensor:
        """Feed the encoder latent through the decoder's projection, masked decoding, and prediction head."""
        decoder_latent_projection = proj(encoder_latent)
        decoder_tokens = decoder.forward_masked(decoder_latent_projection, ind_restore)
        predicted_reconstruction = pred(decoder_tokens)
        return predicted_reconstruction[:, 1:, :]  # drop the class token
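
    # forward() masks a fraction (mask_ratio) of the patch tokens, encodes the visible
    # tokens, and decodes to per-patch pixel predictions. It returns the encoder
    # latent, the flattened reconstruction, and the binary patch mask used.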
    def forward(
        self, imgs: torch.Tensor, constant_noise: Union[torch.Tensor, None] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        imgs = self.input_norm(imgs)
        latent, mask, ind_restore = self.encoder.forward_masked(
            imgs, self.mask_ratio, constant_noise
        )
        reconstruction = self.decode_to_reconstruction(
            latent,
            ind_restore,
            self.encoder_decoder_proj,
            self.decoder,
            self.decoder_pred,
        )
        return latent, reconstruction, mask
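
    # The reconstruction loss is per-patch MSE averaged over pixels and then over the
    # masked patches only ((loss * mask).sum() / mask.sum()). When fourier_loss_weight
    # is non-zero, the Fourier loss is blended in as a convex combination:
    # (1 - w) * mse + w * fourier.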
    def compute_MAE_loss(
        self,
        reconstruction: torch.Tensor,
        img: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute the final loss and return the component losses for metric reporting."""
        loss_dict = {}
        img = self.input_norm(img)
        target_flattened = flatten_images(
            img,
            patch_size=self.patch_size,
            channel_agnostic=self.encoder.channel_agnostic,
        )

        loss: torch.Tensor = self.loss(reconstruction, target_flattened)
        loss = loss.mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum()
        loss_dict[self.RECON_LOSS] = loss.item()

        if self.fourier_loss_weight > 0:
            floss: torch.Tensor = self.fourier_loss(reconstruction, target_flattened)
            if not self.mask_fourier_loss:
                floss = floss.mean()
            else:
                floss = floss.mean(dim=-1)
                floss = (floss * mask).sum() / mask.sum()

            loss_dict[self.FOURIER_LOSS] = floss.item()

            loss = (1 - self.fourier_loss_weight) * loss + (
                self.fourier_loss_weight * floss
            )
        return loss, loss_dict
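
    # Lightning-style step: expects batch["pixels"]; the image is cloned before the
    # forward pass so the untouched batch tensor can be reused as the reconstruction
    # target in compute_MAE_loss.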
    def training_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
        img = batch["pixels"]
        latent, reconstruction, mask = self(img.clone())
        full_loss, loss_dict = self.compute_MAE_loss(reconstruction, img.float(), mask)
        return {
            "loss": full_loss,
            **loss_dict,
        }

    def validation_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
        return self.training_step(batch, batch_idx)
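
    # The metric/trainer hooks below assume self.metrics and self.lr_scheduler are
    # attached by the surrounding training harness; they are not created in this class.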
    def update_metrics(self, outputs: TensorDict, batch: TensorDict) -> None:
        self.metrics["lr"].update(value=self.lr_scheduler.get_last_lr())
        for key, value in outputs.items():
            if key.endswith("loss"):
                self.metrics[key].update(value)

    def on_validation_batch_end(
        self,
        outputs: TensorDict,
        batch: TensorDict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # No-op hook; PreTrainedModel has no on_validation_batch_end to delegate to.
        pass
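
    # predict() returns pooled embeddings for downstream use: the mean over all patch
    # tokens (class token excluded), or per-channel means concatenated into a
    # (N, num_channels * embed_dim) vector when return_channelwise_embeddings is set.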
    def predict(self, imgs: torch.Tensor) -> torch.Tensor:
        imgs = self.input_norm(imgs)
        X = self.encoder.vit_backbone.forward_features(imgs)
        if self.return_channelwise_embeddings:
            N, _, d = X.shape
            num_channels = imgs.shape[1]
            X_reshaped = X[:, 1:, :].view(N, num_channels, self.tokens_per_channel, d)
            pooled_segments = X_reshaped.mean(dim=2)
            latent = pooled_segments.view(N, num_channels * d).contiguous()
        else:
            latent = X[:, 1:, :].mean(dim=1)
        return latent
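
    # Checkpoint I/O: despite the default "model.safetensors" filename, the state dict
    # is written with torch.save and read back with torch.load (a pickle checkpoint),
    # wrapped in a {"state_dict": ...} dict; the config is saved alongside it.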
    def save_pretrained(self, save_directory: str, **kwargs):
        filename = kwargs.pop("filename", "model.safetensors")
        modelpath = f"{save_directory}/{filename}"
        self.config.save_pretrained(save_directory)
        torch.save({"state_dict": self.state_dict()}, modelpath)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        filename = kwargs.pop("filename", "model.safetensors")

        modelpath = f"{pretrained_model_name_or_path}/{filename}"
        config = MAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        state_dict = torch.load(modelpath, map_location="cpu")
        model = cls(config, *model_args, **kwargs)
        model.load_state_dict(state_dict["state_dict"])
        return model
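

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original training pipeline). It
    # assumes 6-channel 256x256 inputs, matching the hard-coded CA-MAE decoder
    # (6 modalities x 256 tokens) and the ViT-S/16 backbone; the exact input scale
    # expected by Normalizer may differ in practice.
    config = MAEConfig(mask_ratio=0.75)
    model = MAEModel(config).eval()

    imgs = torch.rand(2, 6, 256, 256)
    with torch.no_grad():
        latent, reconstruction, mask = model(imgs)
        loss, loss_dict = model.compute_MAE_loss(reconstruction, imgs, mask)
        embeddings = model.predict(imgs)  # pooled embeddings for downstream tasks

    print({"loss": float(loss), **loss_dict})
    print(embeddings.shape)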