# ------------------------------------------------------------------------------
# terrafm.py: TerraFM, an Earth observation foundation model for remote sensing.
# This file includes code copied and adapted from DINO:
# - DINO (https://github.com/facebookresearch/dino)
# ------------------------------------------------------------------------------
import math
import warnings
from functools import partial

import torch
import torch.nn as nn
from torch import Tensor
def make_2tuple(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
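
# Illustrative usage sketch (added for documentation; `_demo_trunc_normal` is
# not part of the upstream API): trunc_normal_ fills a tensor in place with
# samples from N(mean, std) truncated to [a, b]. TerraFM uses it below to
# initialize pos_embed, cls_token, and Linear weights.
def _demo_trunc_normal():
    w = torch.empty(3, 5)
    trunc_normal_(w, std=.02)
    assert -2. <= w.min() and w.max() <= 2.  # default truncation range [a, b]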
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples on the residual path (stochastic depth)."""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
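
# Illustrative sketch (hypothetical helper, not upstream API): DropPath is the
# identity at inference and, during training, zeroes entire samples on the
# residual branch, rescaling survivors by 1/keep_prob so the expected value
# is unchanged.
def _demo_drop_path():
    dp = DropPath(drop_prob=0.5)
    x = torch.ones(4, 3, 8)
    dp.eval()
    assert torch.equal(dp(x), x)  # no-op at inference
    dp.train()
    y = dp(x)  # each sample is either all zeros or scaled by 1/0.5 = 2
    return y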
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, return_attention=False):
y, attn = self.attn(self.norm1(x))
if return_attention:
return attn
x = x + self.drop_path(y)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
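
# Illustrative sketch (hypothetical helper): a Block applies pre-norm
# self-attention and an MLP, each with a residual connection; with
# return_attention=True it returns the attention map instead of features.
def _demo_block():
    blk = Block(dim=768, num_heads=12)
    x = torch.randn(1, 197, 768)  # e.g. 196 patch tokens + 1 [CLS] token
    assert blk(x).shape == x.shape
    attn = blk(x, return_attention=True)  # shape [1, 12, 197, 197]
    return attn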
class PatchEmbed(nn.Module):
def __init__(
self,
img_size: int,
embed_dim: int,
patch_size: int,
in_chans_s1: int,
in_chans_s2: int,
):
super().__init__()
attn_dim = embed_dim*3 # from Panopticon design
self.img_size = img_size
self.patch_size = patch_size
num_patches = (img_size // patch_size) * (img_size // patch_size)
self.num_patches = num_patches
self.conv2d_s2_l2a = nn.Conv2d(in_chans_s2, attn_dim, kernel_size=patch_size, stride=patch_size)
self.conv2d_s2_l1c = nn.Conv2d(in_chans_s2, attn_dim, kernel_size=patch_size, stride=patch_size)
self.conv2d_s1 = nn.Conv2d(in_chans_s1, attn_dim, kernel_size=patch_size, stride=patch_size)
self.projection = TokenProjection(embed_dim=embed_dim, attn_dim=attn_dim)
self.s2_l2a_embed = nn.Parameter(torch.zeros(1, attn_dim))
self.s2_l1c_embed = nn.Parameter(torch.zeros(1, attn_dim))
self.s1_embed = nn.Parameter(torch.zeros(1, attn_dim))
self.attn_dim = attn_dim
    def forward(self, x12: Tensor, is_l2a: bool = False) -> Tensor:
        C = x12.shape[1]
        # Route by modality: 2 channels -> Sentinel-1; otherwise Sentinel-2,
        # with the L2A or L1C stem selected by `is_l2a`.
if C == 2:
x = self.conv2d_s1(x12).flatten(2).transpose(1, 2)
x += self.s1_embed
elif is_l2a:
x = self.conv2d_s2_l2a(x12).flatten(2).transpose(1, 2)
x += self.s2_l2a_embed
else:
x = self.conv2d_s2_l1c(x12).flatten(2).transpose(1, 2)
x += self.s2_l1c_embed
x = self.projection(x)
return x
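
# Illustrative sketch (hypothetical helper; shapes assume 224x224 inputs with
# 16x16 patches): PatchEmbed routes a batch by modality. Two channels are
# treated as Sentinel-1; twelve channels use the Sentinel-2 L2A or L1C stem,
# selected by `is_l2a`.
def _demo_patch_embed():
    embed = PatchEmbed(img_size=224, embed_dim=768, patch_size=16,
                       in_chans_s1=2, in_chans_s2=12)
    s1 = torch.randn(1, 2, 224, 224)    # Sentinel-1 (2 channels)
    s2 = torch.randn(1, 12, 224, 224)   # Sentinel-2 (12 bands)
    tok_s1 = embed(s1)                  # routed through conv2d_s1
    tok_l2a = embed(s2, is_l2a=True)    # routed through conv2d_s2_l2a
    assert tok_s1.shape == tok_l2a.shape == (1, 196, 768)  # (224/16)^2 patches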
class TokenProjection(nn.Module):
def __init__(self, embed_dim: int, attn_dim: int):
super().__init__()
self.proj1 = nn.Linear(attn_dim, attn_dim, bias=False)
self.norm_input = nn.LayerNorm(attn_dim)
self.proj2 = nn.Linear(attn_dim, attn_dim)
self.proj3 = nn.Linear(attn_dim, embed_dim)
def forward(self, x: Tensor) -> Tensor:
"""
Applies a sequence of linear projections used for Case 1 & N in modality augmentation.
Steps:
1. proj1 is shared between Case 1 and Case N (acts like value projection in attention).
2. Applies LayerNorm to stabilize training and normalize features.
3. In Case N, proj2 is applied after the weighted mean operation.
4. proj3 projects to the final embedding dimension.
        Args:
            x (Tensor): Input tensor of shape [B, N, attn_dim], where
                B = batch size, N = number of tokens.
        Returns:
            Tensor: Projected output of shape [B, N, embed_dim].
        """
        x = self.proj1(x)        # V in cross-attention
        x = self.norm_input(x)
        x = self.proj2(x)
        x = self.proj3(x)        # final projection to embed_dim
return x
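
# Illustrative sketch (hypothetical helper): with embed_dim=768 the attention
# width is 3*768 = 2304, so TokenProjection maps [B, N, 2304] -> [B, N, 768].
def _demo_token_projection():
    proj = TokenProjection(embed_dim=768, attn_dim=2304)
    tokens = torch.randn(1, 196, 2304)
    assert proj(tokens).shape == (1, 196, 768)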
class TerraFM(nn.Module):
def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
super().__init__()
self.num_features = self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size,
            in_chans_s1=2, in_chans_s2=12, embed_dim=embed_dim)  # Sentinel-1: 2 channels, Sentinel-2: 12 bands
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# Classifier head
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def interpolate_pos_encoding(self, x, w, h):
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
class_pos_embed = self.pos_embed[:, 0]
patch_pos_embed = self.pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_embed.patch_size
h0 = h // self.patch_embed.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode='bicubic',
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def prepare_tokens(self, x):
B, nc, w, h = x.shape
x = self.patch_embed(x) # patch linear embedding
# add the [CLS] token to the embed patch tokens
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# add positional encoding to each token
x = x + self.interpolate_pos_encoding(x, w, h)
return self.pos_drop(x)
def forward_features(self, x):
return self.forward(x)
def forward(self, x):
x = self.prepare_tokens(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x[:, 0]
def get_last_selfattention(self, x):
x = self.prepare_tokens(x)
for i, blk in enumerate(self.blocks):
if i < len(self.blocks) - 1:
x = blk(x)
else:
# return attention of the last block
return blk(x, return_attention=True)
    def get_intermediate_layers(self, x, n=1, return_class_token=False, norm=False):
x = self.prepare_tokens(x)
# we return the output tokens from the `n` last blocks
output = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if len(self.blocks) - i <= n:
output.append(x)
if norm:
output = [self.norm(out) for out in output]
class_tokens = [out[:, 0] for out in output]
output = [out[:, 1:] for out in output]
if return_class_token:
return tuple(zip(output, class_tokens))
return output
    def extract_feature(self, images, return_h_w=True, out_indices=(3, 5, 7, 11)):
        x = self.prepare_tokens(images)
        output = []
        h, w = images.shape[2] // self.patch_embed.patch_size, images.shape[3] // self.patch_embed.patch_size
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in out_indices:
out = x[:, 1:]
out = self.norm(out)
B, _, C = out.shape
out = (
out.reshape(B, h, w, C)
.permute(0, 3, 1, 2)
.contiguous()
)
output.append(out)
return output
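
# Illustrative sketch (hypothetical helper): get_intermediate_layers returns
# patch tokens from the last n blocks, optionally layer-normalized and paired
# with the [CLS] token.
def _demo_intermediate_layers():
    model = TerraFM(img_size=[224], qkv_bias=True).eval()
    x = torch.randn(1, 12, 224, 224)
    outs = model.get_intermediate_layers(x, n=2, return_class_token=True, norm=True)
    patches, cls = outs[-1]
    assert patches.shape == (1, 196, 768) and cls.shape == (1, 768)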
def terrafm_base(patch_size=16, **kwargs):
model = TerraFM(
patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
def terrafm_large(patch_size=16, **kwargs):
model = TerraFM(
        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
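
# Illustrative end-to-end sketch (assumptions: random inputs and randomly
# initialized weights, i.e. no pretrained checkpoint is loaded here). The
# forward pass returns the [CLS] embedding; extract_feature returns pyramid
# feature maps from the blocks listed in out_indices.
if __name__ == "__main__":
    model = terrafm_base().eval()
    x = torch.randn(2, 12, 224, 224)  # a 12-band Sentinel-2 batch
    with torch.no_grad():
        cls = model(x)                    # [2, 768] CLS embeddings
        feats = model.extract_feature(x)  # four maps of shape [2, 768, 14, 14]
    print(cls.shape, [f.shape for f in feats])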