import os
import torch
from torch import nn, load, hub
from torchvision import models
import helpers.manipulation as h_manipulation
from enum import Enum


class ModelTypes(Enum):
    RESNET = 0
    ALEXNET = 1
    DENSENET = 2
    EFFICIENTNET = 3
    GOOGLENET = 4
    MOBILENET = 5
    SQUEEZENET = 6


class LayerTypes(Enum):
    CONVOLUTIONAL = nn.Conv2d
    LINEAR = nn.Linear


class TransformTypes(Enum):
    PAD = "Pad"
    JITTER = "Jitter"
    RANDOM_SCALE = "Random Scale"
    RANDOM_ROTATE = "Random Rotate"
    AD_JITTER = "Additional Jitter"

_hook_activations = None
# Hook function
def _get_activations():
    """
    Creates a hook function used when registering forward hooks to capture
    layer activations.
    :return: A hook that stores the layer's output in ``_hook_activations``
    """
    def hook(model, input, output):
        global _hook_activations
        _hook_activations = output
    return hook


def _get_activation_shape():
    """
    Gets the shape of the most recently captured activation.
    :return: A tuple with the size of the activation, singleton dimensions
        squeezed out
    """
    return _hook_activations.squeeze().size()
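
# Minimal usage sketch of the hook pair above (illustrative only: `conv` and
# the 1x3x32x32 input are assumptions, not part of this module):
#
#     conv = nn.Conv2d(3, 8, kernel_size=3)
#     handle = conv.register_forward_hook(_get_activations())
#     conv(torch.zeros(1, 3, 32, 32))  # forward pass populates _hook_activations
#     print(_get_activation_shape())   # torch.Size([8, 30, 30])
#     handle.remove()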


def setup_model(model):
    """
    Takes in a model type and builds the corresponding model. Currently only
    sets up models for the CIFAR10 dataset: the classifier head is replaced
    with a 10-class layer and the standard checkpoint saved in the models
    directory is loaded. Raises a ValueError if the model type is not valid.
    :param model: An enum of type ModelTypes specifying the model to set up
    :return: The model with its CIFAR10 checkpoint loaded
    """
    curr_dir = os.path.dirname(__file__) + "/../"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    match model:
        case ModelTypes.RESNET:
            base = models.resnet18(
                weights=models.ResNet18_Weights.IMAGENET1K_V1)
            base.fc = nn.Linear(base.fc.in_features, 10)
            base.load_state_dict(
                load(curr_dir + "models/resnet_standard_cifar10.pt",
                     map_location=device))
            model = base
        case ModelTypes.ALEXNET:
            base = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
            base.classifier[6] = nn.Linear(base.classifier[6].in_features, 10)
            base.load_state_dict(
                load(curr_dir + "models/alexnet_standard_cifar10.pt",
                     map_location=device))
            model = base
        case ModelTypes.DENSENET:
            base = models.densenet121(
                weights=models.DenseNet121_Weights.IMAGENET1K_V1)
            base.classifier = nn.Linear(base.classifier.in_features, 10)
            base.load_state_dict(
                load(curr_dir + "models/densenet_standard_cifar10.pt",
                     map_location=device))
            model = base
        case ModelTypes.EFFICIENTNET:
            base = hub.load('NVIDIA/DeepLearningExamples:torchhub',
                            'nvidia_efficientnet_b0', pretrained=True)
            base.classifier.fc = nn.Linear(base.classifier.fc.in_features, 10)
            base.load_state_dict(
                load(curr_dir + "models/efficientnet_standard_cifar10.pt",
                     map_location=device))
            model = base
        case ModelTypes.GOOGLENET:
            base = models.googlenet(
                weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
            base.fc = nn.Linear(base.fc.in_features, 10)
            base.load_state_dict(
                load(curr_dir + "models/googlenet_standard_cifar10.pt",
                     map_location=device))
            model = base
        case ModelTypes.MOBILENET:
            base = models.mobilenet_v2(
                weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
            base.classifier[1] = nn.Linear(base.classifier[1].in_features, 10)
            base.load_state_dict(
                load(curr_dir + "models/mobilenet_standard_cifar10.pt",
                     map_location=device))
            model = base
        case ModelTypes.SQUEEZENET:
            base = models.squeezenet1_0(
                weights=models.SqueezeNet1_0_Weights.IMAGENET1K_V1)
            base.classifier[1] = nn.Conv2d(
                512, 10, kernel_size=(1, 1), stride=(1, 1))
            base.load_state_dict(
                load(curr_dir + "models/squeezenet_standard_cifar10.pt",
                     map_location=device))
            model = base
        case _:
            raise ValueError("Unknown model choice")
    return model
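
# Hedged usage sketch for setup_model (assumes the CIFAR10 checkpoints
# referenced above exist under models/, one directory above this file):
#
#     model = setup_model(ModelTypes.RESNET)
#     model.eval()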


def get_layer_by_name(model, layer_name):
    """
    Gets a layer from a model given its name. Nested submodules are separated
    with underscores, e.g. "classifier_1". Raises a ValueError if the layer
    is not found in the model.
    :param model: Model to look for the layer in
    :param layer_name: Name of the layer to search for
    :return: The layer found
    """
    current_layer = model
    layer_names = layer_name.split('_')  # Split into sub-layer names
    for name in layer_names:
        if isinstance(current_layer, nn.Module):
            current_layer = getattr(current_layer, name, None)
        else:
            raise ValueError(f"Layer '{layer_name}' not found in the model.")
    if current_layer is None:
        raise ValueError(f"Layer '{layer_name}' not found in the model.")
    return current_layer
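
# Illustrative call ("features_0" is a hypothetical layer name; for AlexNet
# it resolves to the first convolutional layer, i.e. model.features[0]):
#
#     alexnet = setup_model(ModelTypes.ALEXNET)
#     first_conv = get_layer_by_name(alexnet, "features_0")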


def get_feature_map_sizes(model, layers, img=None):
    """
    Gets the feature map sizes of the given layers, used to dynamically limit
    values on the interface.
    :param model: Model to pass the image through
    :param layers: Layers to grab feature map sizes from
    :param img: Image to use for the forward pass; a random CIFAR10-normalized
        image is generated if None
    :return: Feature map sizes, one entry per layer (None for non-conv layers)
    """
    feature_map_sizes = [None] * len(layers)
    if img is None:
        # TODO Remove this and just generate a blank 227 by 227 image
        img = h_manipulation.create_random_image(
            (227, 227),
            h_manipulation.DatasetNormalizations.CIFAR10_MEAN.value,
            h_manipulation.DatasetNormalizations.CIFAR10_STD.value
        ).clone().unsqueeze(0)
    else:
        img = img.unsqueeze(0)
    # Use GPU if possible
    train_on_gpu = torch.cuda.is_available()
    if train_on_gpu:
        print("Using GPU for generation")
        model = model.cuda()
        img = img.cuda()
    else:
        print("Using CPU for generation")
    model = model.eval()
    # Fake forward pass so the hooks capture activations
    for index, layer in enumerate(layers):
        if isinstance(layer, nn.Conv2d):
            hook = layer.register_forward_hook(_get_activations())
            model(img)
            # The captured activations carry the feature map sizes
            feature_map_sizes[index] = _get_activation_shape()
            hook.remove()
    return feature_map_sizes
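

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module; assumes the resnet checkpoint exists under models/):
    demo_model = setup_model(ModelTypes.RESNET)
    first_conv = get_layer_by_name(demo_model, "conv1")  # resnet18's stem conv
    print(get_feature_map_sizes(demo_model, [first_conv]))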