SFM_Inference_Demo / models_Fault.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm.models.vision_transformer
import numpy as np
from util.msssim import MSSSIM
from util.pos_embed import get_2d_sincos_pos_embed
from util.variable_pos_embed import interpolate_pos_embed_variable
class FlexiblePatchEmbed(nn.Module):
""" 2D Image to Patch Embedding that handles variable input sizes """
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, bias=True):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.num_patches = (img_size // patch_size) ** 2 # default number of patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
def forward(self, x):
B, C, H, W = x.shape
# Calculate number of patches dynamically
self.num_patches = (H // self.patch_size) * (W // self.patch_size)
x = self.proj(x).flatten(2).transpose(1, 2) # BCHW -> BNC
return x
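
# Illustrative usage sketch (not part of the original module): FlexiblePatchEmbed accepts any
# input whose sides are multiples of the patch size, and the token count is simply
# (H // patch_size) * (W // patch_size). The shapes below are assumptions for demonstration.
def _demo_flexible_patch_embed():
    embed = FlexiblePatchEmbed(img_size=224, patch_size=16, in_chans=1, embed_dim=768)
    x = torch.randn(2, 1, 128, 320)  # non-square input, both sides divisible by 16
    tokens = embed(x)                # (B, N, C) with N = (128 // 16) * (320 // 16) = 160
    assert tokens.shape == (2, 160, 768)
    return tokens
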
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer encoder with a UNet-style decoder (DecoderCup) and a
    SegmentationHead for fault prediction; optional global average pooling.
    A usage sketch follows the class definition.
    """
    def __init__(self, global_pool=False, **kwargs):
super(VisionTransformer, self).__init__(**kwargs)
self.global_pool = global_pool
        # UNet-style decoder over the ViT tokens, followed by a 1x1-conv segmentation head
        self.decoder = DecoderCup(in_channels=[self.embed_dim, 256, 128, 64])
        self.segmentation_head = SegmentationHead(
            in_channels=64,
            out_channels=self.num_classes,
            kernel_size=1
        )
if self.global_pool:
norm_layer = kwargs['norm_layer']
embed_dim = kwargs['embed_dim']
self.fc_norm = norm_layer(embed_dim)
del self.norm # remove the original norm
def interpolate_pos_encoding(self, x, h, w):
"""
Interpolate positional embeddings for arbitrary input sizes
"""
npatch = x.shape[1] - 1 # subtract 1 for cls token
N = self.pos_embed.shape[1] - 1 # original number of patches
if npatch == N and h == w:
return self.pos_embed
# Use the new variable position embedding utility
return interpolate_pos_embed_variable(self.pos_embed, h, w, cls_token=True)
    def generate_mask(self, input_tensor, ratio):
        # Build a binary mask covering a random subset of 16-pixel-wide vertical strips
        # along the width dimension; `ratio` is the fraction of strips that are set to 1.
        mask = torch.zeros_like(input_tensor)
        indices = torch.randperm(mask.size(3) // 16)[:int(mask.size(3) // 16 * ratio)]
        sorted_indices = torch.sort(indices)[0]
        for i in sorted_indices.tolist():
            mask[:, :, :, i * 16:(i + 1) * 16] = 1
        return mask
def forward_features(self, x):
B,C,H,W = x.shape
# Handle padding for non-16-divisible images
patch_size = self.patch_embed.patch_size
pad_h = (patch_size - H % patch_size) % patch_size
pad_w = (patch_size - W % patch_size) % patch_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
H_padded, W_padded = H + pad_h, W + pad_w
else:
H_padded, W_padded = H, W
img = x
x = self.patch_embed(x)
_H, _W = H_padded // patch_size, W_padded // patch_size
# Add class token
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# Add interpolated positional embeddings
pos_embed = self.interpolate_pos_encoding(x, _H, _W)
x = x + pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
        # self.norm is deleted when global_pool is enabled, so use fc_norm on the tokens instead
        x = self.fc_norm(x) if self.global_pool else self.norm(x)
x = self.decoder(x[:, 1:, :], img)
x = self.segmentation_head(x)
return x
def forward(self, x):
x = self.forward_features(x)
return x
def inference(self, x):
x = self.forward_features(x)
x = F.softmax(x, dim=1)
return x
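
# Hedged end-to-end sketch (illustrative, not part of the original module), assuming a timm
# version compatible with this subclass and that util.variable_pos_embed interpolates the
# position embedding to the padded patch grid. The factories at the bottom of this file build
# this class for single-channel sections; inputs that are not multiples of 16 are reflect-padded
# internally, so the class probabilities come back at the padded resolution.
def _demo_fault_inference():
    model = vit_base_patch16(img_size=224, in_chans=1, num_classes=2)
    model.eval()
    x = torch.randn(1, 1, 200, 300)    # arbitrary single-channel section
    with torch.no_grad():
        probs = model.inference(x)     # softmax over the class dimension
    # 200 -> 208 and 300 -> 304 after padding to multiples of 16
    print(probs.shape)                 # expected: torch.Size([1, 2, 208, 304])
    return probs
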
class Conv2dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
)
        relu = nn.ReLU(inplace=True)
        # Honour the use_batchnorm flag; fall back to an identity layer when it is disabled
        bn = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
        super(Conv2dReLU, self).__init__(conv, bn, relu)
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
skip_channels=0,
use_batchnorm=True,
):
super().__init__()
self.conv1 = Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
def forward(self, x, skip=None):
x = self.up(x)
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x
class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=1, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=0)
        upsample = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsample)
class DecoderCup(nn.Module):
    def __init__(self, in_channels=[1024, 256, 128, 64]):
        super().__init__()
        # Full-resolution skip path: a 3x3 conv on the single-channel input image whose 32
        # feature maps are concatenated before the last decoder block.
        self.conv_more = Conv2dReLU(
            1,
            32,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        skip_channels = [0, 0, 0, 32]
        out_channels = [256, 128, 64, 64]
blocks = [
DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
]
self.blocks = nn.ModuleList(blocks)
    def forward(self, hidden_states, img, features=None):
        B, n_patch, hidden = hidden_states.size()
        # Reshape the token sequence (B, n_patch, hidden) back to a 2D map (B, hidden, h, w).
        # The patch grid is recovered from the (padded) image size, assuming the 16x16 patches
        # used throughout this file, so non-square inputs are handled as well.
        h, w = img.shape[2] // 16, img.shape[3] // 16
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        # Per-block skip inputs; only the last block receives the image features from conv_more
        skips = [None, None, None, self.conv_more(img)]
        for i, decoder_block in enumerate(self.blocks):
            x = decoder_block(x, skip=skips[i])
return x
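
# Shape sketch for the decoder (illustrative, not part of the original module): tokens from a
# 16x-downsampled patch grid plus the single-channel image skip are decoded back to full
# resolution through four x2 upsampling blocks.
def _demo_decoder_cup():
    decoder = DecoderCup(in_channels=[768, 256, 128, 64])
    img = torch.randn(1, 1, 224, 224)                        # single-channel input image
    tokens = torch.randn(1, (224 // 16) * (224 // 16), 768)  # (B, n_patch, hidden)
    out = decoder(tokens, img)                               # 14x14 grid -> 224x224 feature map
    assert out.shape == (1, 64, 224, 224)
    return out
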
def forward_loss(imgs, pred):
"""
imgs: [N, 3, H, W]
pred: [N, L, p*p*3]
mask: [N, L], 0 is keep, 1 is remove,
"""
loss1f = torch.nn.MSELoss()
loss1 = loss1f(imgs, pred)
loss2f = MSSSIM()
loss2 = loss2f(imgs, pred)
a = 0.5
loss = (1-a)*loss1+a*loss2
return loss
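
# Usage sketch for the blended reconstruction loss (illustrative; assumes util.msssim.MSSSIM is
# callable on a pair of image batches the same way torch.nn.MSELoss is, which matches how it is
# used above). MS-SSIM downsamples several times internally, so a reasonably large image is used.
def _demo_forward_loss():
    target = torch.rand(2, 1, 256, 256)
    pred = target + 0.05 * torch.randn_like(target)
    return forward_loss(target, pred)  # 0.5 * MSE + 0.5 * MS-SSIM term
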
def weighted_cross_entropy(pred, target):
    """
    Class-balanced binary cross entropy for fault segmentation.
    pred:   [batch, channel, s, s] predicted fault probabilities in (0, 1)
    target: [batch, channel, s, s] binary fault labels
    NEED VERIFICATION
    """
    # beta is the fraction of non-fault pixels in the target (i.e. the zeros); fault pixels are
    # rare, so the positive term is weighted by beta and the negative term by (1 - beta).
    beta = torch.mean(target)  # fraction of fault pixels
    beta = 1 - beta  # fraction of non-fault pixels
    beta = torch.clamp(beta, min=0.01, max=0.99)  # avoid degenerate weights when the target is all 0s or all 1s
    # Weighted binary cross entropy, with a small epsilon for numerical stability
    loss = -(beta * target * torch.log(pred + 1e-8) + (1 - beta) * (1 - target) * torch.log(1 - pred + 1e-8))
    return torch.mean(loss)
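
# Worked example for the class-balanced loss (illustrative, not part of the original module):
# with roughly 5% fault pixels, beta is about 0.95, so each positive (fault) term is weighted
# about 19x more heavily than each negative term, counteracting the class imbalance.
def _demo_weighted_cross_entropy():
    target = (torch.rand(4, 1, 64, 64) < 0.05).float()           # sparse binary fault labels
    pred = torch.clamp(torch.rand_like(target), 1e-3, 1 - 1e-3)  # dummy probabilities in (0, 1)
    return weighted_cross_entropy(pred, target)
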
def mae_vit_small_patch16(**kwargs):
model = VisionTransformer(
patch_size=16, embed_dim=768, depth=6, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
# Replace with flexible patch embedding
model.patch_embed = FlexiblePatchEmbed(
img_size=kwargs.get('img_size', 224),
patch_size=16,
in_chans=kwargs.get('in_chans', 3),
embed_dim=768
)
return model
def vit_base_patch16(**kwargs):
model = VisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
# Replace with flexible patch embedding
model.patch_embed = FlexiblePatchEmbed(
img_size=kwargs.get('img_size', 224),
patch_size=16,
in_chans=kwargs.get('in_chans', 3),
embed_dim=768
)
return model
def vit_large_patch16(**kwargs):
model = VisionTransformer(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
# Replace with flexible patch embedding
model.patch_embed = FlexiblePatchEmbed(
img_size=kwargs.get('img_size', 224),
patch_size=16,
in_chans=kwargs.get('in_chans', 3),
embed_dim=1024
)
return model
def vit_huge_patch14(**kwargs):
model = VisionTransformer(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
# Replace with flexible patch embedding
model.patch_embed = FlexiblePatchEmbed(
img_size=kwargs.get('img_size', 224),
patch_size=14,
in_chans=kwargs.get('in_chans', 3),
embed_dim=1280
)
return model
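
# Hedged construction sketch for the factories above (the checkpoint path and key layout are
# illustrative assumptions, not a documented API of this repo): build a single-channel
# fault-segmentation model and optionally load pretrained encoder weights non-strictly, so the
# decoder and segmentation head keep their fresh initialisation.
def _build_fault_model(checkpoint_path=None, num_classes=2):
    model = vit_large_patch16(img_size=224, in_chans=1, num_classes=num_classes)
    if checkpoint_path is not None:
        state = torch.load(checkpoint_path, map_location="cpu")
        state = state.get("model", state)  # some checkpoints nest the weights under "model"
        missing, unexpected = model.load_state_dict(state, strict=False)
        print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    return model
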