Spaces:
Running
Running
from typing import Optional, Union | |
import torch | |
from torch import device | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.models as tvm | |
import gc | |
class ResNet50(nn.Module):
    """ResNet-50 backbone that exposes its intermediate feature maps.

    ``forward`` returns a dict whose integer keys (1, 2, 4, 8, 16, 32) label
    the successive stages of the network; see ``forward`` for the mapping.

    Args:
        pretrained: Forwarded to ``tvm.resnet50(pretrained=...)`` when
            ``weights`` is ``None``.
        high_res: Stored on the instance; not read anywhere in this class.
        weights: When not ``None``, forwarded to ``tvm.resnet50(weights=...)``
            and ``pretrained`` is ignored.
        dilation: Forwarded as ``replace_stride_with_dilation``. Defaults to
            ``[False, False, False]``.
        freeze_bn: When true, ``train()`` keeps every ``nn.BatchNorm2d`` in
            eval mode.
        anti_aliased: Not implemented. Passing a truthy value raises
            ``NotImplementedError``.
        early_exit: When true, ``forward`` returns right after ``layer2``.
        amp: Enables CUDA autocast inside ``forward``.

    Raises:
        NotImplementedError: If ``anti_aliased`` is truthy.
    """

    def __init__(
        self,
        pretrained=False,
        high_res=False,
        weights=None,
        dilation=None,
        freeze_bn=True,
        anti_aliased=False,
        early_exit=False,
        amp=False,
    ) -> None:
        super().__init__()
        if anti_aliased:
            # Previously this branch silently left ``self.net`` unset, which
            # only surfaced later as an AttributeError inside ``forward``.
            raise NotImplementedError(
                "anti_aliased ResNet50 is not implemented"
            )
        if dilation is None:
            dilation = [False, False, False]
        if weights is not None:
            self.net = tvm.resnet50(
                weights=weights, replace_stride_with_dilation=dilation
            )
        else:
            self.net = tvm.resnet50(
                pretrained=pretrained, replace_stride_with_dilation=dilation
            )
        self.high_res = high_res
        self.freeze_bn = freeze_bn
        self.early_exit = early_exit
        self.amp = amp
        # Prefer bfloat16 for autocast when the CUDA device supports it.
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            self.amp_dtype = torch.bfloat16
        else:
            self.amp_dtype = torch.float16

    def forward(self, x, **kwargs):
        """Run the backbone and collect the feature map of each stage.

        Args:
            x: Input tensor fed to ``net.conv1``.
            **kwargs: Accepted and ignored.

        Returns:
            dict mapping
            1 -> the input ``x``,
            2 -> output of conv1/bn1/relu,
            4 -> output of maxpool + layer1,
            8 -> output of layer2,
            and, unless ``early_exit`` is set,
            16 -> output of layer3 and 32 -> output of layer4.
        """
        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
            net = self.net
            feats = {1: x}
            x = net.conv1(x)
            x = net.bn1(x)
            x = net.relu(x)
            feats[2] = x
            x = net.maxpool(x)
            x = net.layer1(x)
            feats[4] = x
            x = net.layer2(x)
            feats[8] = x
            if self.early_exit:
                return feats
            x = net.layer3(x)
            feats[16] = x
            x = net.layer4(x)
            feats[32] = x
            return feats

    def train(self, mode=True):
        """Set train/eval mode, keeping BatchNorm frozen when ``freeze_bn``.

        Returns:
            ``self``, matching ``nn.Module.train`` so that chained calls such
            as ``model.train().cuda()`` or ``model = model.eval()`` work.
        """
        super().train(mode)
        if self.freeze_bn:
            # Re-apply eval mode to every BatchNorm2d after the generic
            # train() call has switched all submodules to ``mode``.
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        return self
class VGG19(nn.Module):
    """Feature extractor built from the first 40 layers of ``tvm.vgg19_bn``.

    Args:
        pretrained: Forwarded to ``tvm.vgg19_bn(pretrained=...)``.
        amp: Enables CUDA autocast inside ``forward``.
    """

    def __init__(self, pretrained=False, amp=False) -> None:
        super().__init__()
        backbone = tvm.vgg19_bn(pretrained=pretrained)
        # Only the first 40 entries of the ``features`` stack are kept.
        self.layers = nn.ModuleList(backbone.features[:40])
        self.amp = amp
        # Prefer bfloat16 for autocast when the CUDA device supports it.
        bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        self.amp_dtype = torch.bfloat16 if bf16_ok else torch.float16

    def forward(self, x, **kwargs):
        """Run the layers, snapshotting the activation before each max-pool.

        Args:
            x: Input tensor fed to the first layer.
            **kwargs: Accepted and ignored.

        Returns:
            dict whose keys start at 1 and double at every ``nn.MaxPool2d``
            encountered; each value is the tensor entering that pooling layer.
        """
        pyramid = {}
        level = 1
        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
            for block in self.layers:
                if isinstance(block, nn.MaxPool2d):
                    # Record the pre-pool activation under the current level.
                    pyramid[level] = x
                    level *= 2
                x = block(x)
        return pyramid
class CNNandDinov2(nn.Module):
    """CNN feature pyramid whose level 16 is replaced by DINOv2 ViT-L/14 features.

    The ViT is kept inside a plain Python list, so it is not registered as a
    submodule: its parameters are not exposed through ``parameters()`` and it
    is not moved by ``.to()`` on this wrapper. ``forward`` therefore moves it
    to the input's device on demand.

    Args:
        cnn_kwargs: Keyword arguments for the CNN constructor (``ResNet50`` or
            ``VGG19``). ``None`` means no extra arguments.
        amp: When true, the ViT is cast to ``amp_dtype`` at construction.
        use_vgg: Selects ``VGG19`` instead of ``ResNet50`` as the CNN.
        dinov2_weights: State dict for the ViT. When ``None`` the official
            ``dinov2_vitl14_pretrain.pth`` checkpoint is downloaded.
    """

    def __init__(self, cnn_kwargs=None, amp=False, use_vgg=False, dinov2_weights=None):
        super().__init__()
        if dinov2_weights is None:
            dinov2_weights = torch.hub.load_state_dict_from_url(
                "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
                map_location="cpu",
            )
        from .transformer import vit_large

        vit_kwargs = dict(
            img_size=518,
            patch_size=14,
            init_values=1.0,
            ffn_layer="mlp",
            block_chunks=0,
        )
        dinov2_vitl14 = vit_large(**vit_kwargs).eval()
        dinov2_vitl14.load_state_dict(dinov2_weights)
        cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
        if not use_vgg:
            self.cnn = ResNet50(**cnn_kwargs)
        else:
            self.cnn = VGG19(**cnn_kwargs)
        self.amp = amp
        # Prefer bfloat16 when the CUDA device supports it.
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            self.amp_dtype = torch.bfloat16
        else:
            self.amp_dtype = torch.float16
        if self.amp:
            dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
        self.dinov2_vitl14 = [dinov2_vitl14]  # ugly hack to not show parameters to DDP

    def train(self, mode: bool = True):
        """Switch the CNN between train and eval mode.

        Only the CNN is toggled; the ViT lives outside the module tree and is
        left in the eval mode it was put in at construction.

        Returns:
            ``self``, matching ``nn.Module.train``. The previous version
            returned the result of ``self.cnn.train(mode)``, so
            ``model = model.eval()`` handed back the CNN (or ``None``).
        """
        # Keep the wrapper's own flag in sync, as nn.Module.train would.
        self.training = mode
        self.cnn.train(mode)
        return self

    def forward(self, x, upsample=False):
        """Compute the CNN pyramid and, unless ``upsample``, add ViT features.

        Args:
            x: 4-D input tensor, unpacked as ``(B, C, H, W)``.
            upsample: When true, only the CNN pyramid is returned and the ViT
                is not run.

        Returns:
            The dict produced by the CNN. When ``upsample`` is false, key 16
            is set to the ViT patch tokens reshaped to
            ``(B, 1024, H // 14, W // 14)``.
        """
        B, C, H, W = x.shape
        feature_pyramid = self.cnn(x)
        if not upsample:
            with torch.no_grad():
                # Lazily move the hidden ViT to the input's device.
                if self.dinov2_vitl14[0].device != x.device:
                    self.dinov2_vitl14[0] = (
                        self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
                    )
                # NOTE(review): the input is always cast to amp_dtype here,
                # while the ViT is only cast at construction when ``amp`` is
                # true (or on a device move above). Confirm that amp=False
                # with matching devices is never used.
                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(
                    x.to(self.amp_dtype)
                )
                # Tokens come as (B, N, 1024); presumably H and W must be
                # multiples of 14 for the reshape to succeed. TODO confirm.
                features_16 = (
                    dinov2_features_16["x_norm_patchtokens"]
                    .permute(0, 2, 1)
                    .reshape(B, 1024, H // 14, W // 14)
                )
                del dinov2_features_16
                feature_pyramid[16] = features_16
        return feature_pyramid