Spaces:
Sleeping
Sleeping
import mmpretrain | |
import torch | |
from torch import nn | |
from collections.abc import Iterable | |
from mmpretrain.models.utils.attention import MultiheadAttention | |
# Registry of model factories, keyed by (dataset_name, model_name) tuples.
# Each value is a callable taking a device and returning a wrapped model.
MODEL_REGISTRY: dict = {}
class ClsModel(nn.Module):
    """Base class for classification models handled by this module.

    Stores the dataset name, model name and device the model was created
    with. Subclasses are expected to override ``head_features``,
    ``num_classes`` and ``forward``.
    """

    dataset_name: str
    model_name: str
    device: str

    def __init__(self, dataset_name: str, model_name: str, device: str) -> None:
        super().__init__()
        self.dataset_name = dataset_name
        self.model_name = model_name
        self.device = device

    def head_features(self):
        """Return the input feature count of the final classification layer.

        Raises:
            NotImplementedError: always, on the base class.
        """
        # Fail loudly (like ``forward``) instead of silently returning None.
        raise NotImplementedError(
            "head_features not implemented for base class")

    def num_classes(self):
        """Return the number of output classes.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError(
            "num_classes not implemented for base class")

    def forward(self, x):
        """
        x: [B, 3 (RGB), H, W] image (float) [0,1]
        returns: [B, C] class logits
        """
        raise NotImplementedError("Forward not implemented for base class")
class TimmPretrainModelWrapper(ClsModel):
    """
    Calls data preprocessing for model before entering forward.

    ``transform`` is applied to every input batch at the start of
    ``forward``; the wrapped ``model`` must expose ``forward_features`` and
    ``forward_head`` plus either a ``head`` or an ``fc`` classifier attribute.
    """

    def __init__(self, model: nn.Module, transform, dataset_name: str, model_name: str, device: str) -> None:
        super().__init__(dataset_name, model_name, device)
        self.model = model
        self.transform = transform

    def final_linear_layer(self):
        """Return the last linear layer of the wrapped model.

        Uses ``model.head`` when present (directly if it is an ``nn.Linear``,
        otherwise its ``fc`` child) and falls back to ``model.fc`` when the
        model has no ``head`` attribute.
        """
        if not hasattr(self.model, "head"):
            return self.model.fc
        head = self.model.head
        if isinstance(head, torch.nn.Linear):
            return head
        return head.fc

    def head_features(self):
        """Return ``in_features`` of the final linear layer."""
        # final_linear_layer is a method, so it must be called before its
        # attributes can be read.
        return self.final_linear_layer().in_features

    def num_classes(self):
        """Return ``out_features`` of the final linear layer."""
        return self.final_linear_layer().out_features

    def head(self, feats):
        """Apply ``model.head`` to ``feats`` wrapped in a 1-tuple."""
        # NOTE(review): the tuple wrapping mirrors the mmpretrain wrapper in
        # this file; confirm the wrapped model's head accepts a tuple.
        return self.model.head((feats,))

    def head_matrices(self):
        """Return the ``(weight, bias)`` pair of the final linear layer."""
        layer = self.final_linear_layer()
        return layer.weight, layer.bias

    def forward(self, x, return_features=False):
        """Preprocess ``x`` with ``transform`` and run the wrapped model.

        Args:
            x: input batch passed through ``self.transform`` first.
            return_features: when true, also return the pre-logit features.

        Returns:
            The model output, or ``(preds, pre_logits)`` when
            ``return_features`` is true.
        """
        x = self.transform(x)
        if not return_features:
            return self.model(x)

        feats = self.model.forward_features(x)
        pre_logits = self.model.forward_head(feats, pre_logits=True)
        # Models with an ``fc`` classifier (convnets) use it; everything else
        # (e.g. ViT) goes through ``head``. Selecting the layer explicitly
        # avoids a bare ``except`` that would hide genuine runtime errors.
        classifier = getattr(self.model, "fc", None)
        if classifier is None:
            classifier = self.model.head
        preds = classifier(pre_logits)
        return preds, pre_logits
class MMPretrainModelWrapper(ClsModel):
    """
    Calls data preprocessing for model before entering forward.

    The wrapped ``model`` must provide ``data_preprocessor``,
    ``extract_feat`` and a ``head`` module.
    """

    def __init__(self, model: nn.Module, dataset_name: str, model_name: str, device: str) -> None:
        super().__init__(dataset_name, model_name, device)
        self.model = model

    def final_linear_layer(self):
        """Return the final linear layer, ``model.head.fc``."""
        return self.model.head.fc

    def head_features(self):
        """Return ``in_features`` of the final linear layer."""
        # final_linear_layer is a method (subclasses override it as one), so
        # it must be called before its attributes can be read.
        return self.final_linear_layer().in_features

    def num_classes(self):
        """Return ``out_features`` of the final linear layer."""
        return self.final_linear_layer().out_features

    def head(self, feats):
        """Apply ``model.head`` to ``feats`` wrapped in a 1-tuple."""
        return self.model.head((feats,))

    def head_matrices(self):
        """Return the ``(weight, bias)`` pair of the final linear layer."""
        layer = self.final_linear_layer()
        return layer.weight, layer.bias

    def forward(self, x, return_features=False):
        """Preprocess ``x`` and run the wrapped model.

        Args:
            x: image batch as floats in [0, 1]; it is rescaled to 0-255
                before being handed to the model's data preprocessor.
            return_features: when true, also return the extracted features.

        Returns:
            The model output, or ``(preds, feats)`` when ``return_features``
            is true, where ``feats`` is the last element of the extracted
            feature tuple/list.
        """
        # Data preprocessor expects 0-255 range, but we don't want to cast to proper
        # uint8 because we want to maintain differentiability
        x = x * 255.
        x = self.model.data_preprocessor({"inputs": x})["inputs"]
        if not return_features:
            return self.model(x)

        feats = self.model.extract_feat(x)
        preds = self.model.head(feats)
        # Only unwrap real containers: a torch.Tensor is itself iterable, so an
        # ``Iterable`` check would index a bare tensor along its batch axis.
        if isinstance(feats, (tuple, list)):
            feats = feats[-1]
        return preds, feats
class MMPretrainVisualTransformerWrapper(MMPretrainModelWrapper):
    """mmpretrain wrapper for transformer models that can expose attention maps.

    Collects every ``MultiheadAttention`` submodule of the wrapped model at
    construction time so ``get_attention_maps`` can hook them temporarily.
    """

    def __init__(self, model, dataset_name: str, model_name: str, device: str) -> None:
        super().__init__(model, dataset_name, model_name, device)
        self.attn_layers = [
            module for module in model.modules()
            if isinstance(module, MultiheadAttention)
        ]

    def final_linear_layer(self):
        """Return the final linear layer, ``model.head.layers.head``."""
        return self.model.head.layers.head

    def get_attention_maps(self, x):
        """Run one forward pass and return the attention weights of each layer.

        Every layer in ``self.attn_layers`` has its
        ``scaled_dot_product_attention`` attribute replaced by a recording
        implementation for the duration of the pass; the original attribute
        is restored afterwards, even if the forward pass raises.

        Args:
            x: input batch, forwarded through ``MMPretrainModelWrapper.forward``.

        Returns:
            List of post-softmax (pre-dropout) attention weight tensors, in
            the order the attention function was invoked.
        """
        attention_maps = []

        def recording_attention(query,
                                key,
                                value,
                                attn_mask=None,
                                dropout_p=0.,
                                scale=None,
                                is_causal=False):
            # NOTE(review): ``scale`` is used as a divisor here, whereas
            # torch.nn.functional.scaled_dot_product_attention treats it as a
            # multiplier -- confirm which convention the layers rely on.
            scale = scale or query.size(-1)**0.5
            if is_causal:
                # Build the lower-triangular mask on the query's device.
                attn_mask = torch.ones(
                    query.size(-2), key.size(-2),
                    dtype=torch.bool, device=query.device).tril(diagonal=0)
            if attn_mask is not None and attn_mask.dtype == torch.bool:
                # Convert the boolean keep-mask into an additive bias; ``~`` is
                # the element-wise negation (``not`` fails on tensors).
                attn_mask = torch.zeros_like(
                    attn_mask, dtype=query.dtype).masked_fill(
                        ~attn_mask, -float('inf'))
            attn_weight = query @ key.transpose(-2, -1) / scale
            if attn_mask is not None:
                attn_weight = attn_weight + attn_mask
            attn_weight = torch.softmax(attn_weight, dim=-1)
            attention_maps.append(attn_weight)
            attn_weight = torch.dropout(attn_weight, dropout_p, True)
            return attn_weight @ value

        # Remember the attribute that is actually patched so it can be restored.
        original_attention = [
            layer.scaled_dot_product_attention for layer in self.attn_layers
        ]
        for layer in self.attn_layers:
            layer.scaled_dot_product_attention = recording_attention
        try:
            super().forward(x, False)
        finally:
            for layer, original in zip(self.attn_layers, original_attention):
                layer.scaled_dot_product_attention = original
        return attention_maps
def register_mmcls_model(config_name, dataset_name, model_name,
                         wrapper_class=MMPretrainModelWrapper):
    """Register a lazy factory for a pretrained mmpretrain model.

    The factory is stored in ``MODEL_REGISTRY`` under
    ``(dataset_name, model_name)``. Nothing is loaded until the factory is
    called with a device; it then fetches ``config_name`` through
    ``mmpretrain.get_model`` with ``pretrained=True`` and wraps the result in
    ``wrapper_class``.
    """
    registry_key = (dataset_name, model_name)

    def instantiate_model(device):
        backbone = mmpretrain.get_model(
            config_name, pretrained=True, device=device)
        return wrapper_class(backbone, dataset_name, model_name, device)

    MODEL_REGISTRY[registry_key] = instantiate_model
def register_default_models():
    """Register the built-in set of pretrained mmpretrain models."""
    # (config_name, dataset_name, model_name) entries that use the default
    # wrapper class of register_mmcls_model.
    convnet_specs = (
        ("resnet18_8xb16_cifar10", "cifar10", "resnet18"),
        ("resnet34_8xb16_cifar10", "cifar10", "resnet34"),
        ("resnet18_8xb32_in1k", "imagenet", "resnet18"),
        ("resnet50_8xb16_cifar100", "cifar100", "resnet50"),
        ("resnet50_8xb32_in1k", "imagenet", "resnet50"),
        ("densenet121_3rdparty_in1k", "imagenet", "densenet121"),
    )
    for spec in convnet_specs:
        register_mmcls_model(*spec)

    # Transformer entries need the wrapper that can expose attention maps.
    transformer_specs = (
        ("deit-small_4xb256_in1k", "imagenet", "deit_small"),
        ("vit-base-p16_32xb128-mae_in1k", "imagenet", "vit_base"),
    )
    for spec in transformer_specs:
        register_mmcls_model(
            *spec, wrapper_class=MMPretrainVisualTransformerWrapper)
def get_model(dataset_name, model_name, device):
    """
    Build the registered model for ``(dataset_name, model_name)`` on
    ``device`` and return it after calling ``eval()``.

    Raises ``KeyError`` when no factory is registered under that key.
    """
    factory = MODEL_REGISTRY[(dataset_name, model_name)]
    model = factory(device)
    return model.eval()