# Spaces: Runtime error (status banner captured from the hosting page; not part of the module)
import os | |
import torch | |
from torch import nn, load, hub | |
from torchvision import models | |
import helpers.manipulation as h_manipulation | |
from enum import Enum | |
class ModelTypes(Enum):
    """
    Model architectures that setup_model can build.

    Members are numbered consecutively from 0 in declaration order.
    """

    (
        RESNET,        # 0
        ALEXNET,       # 1
        DENSENET,      # 2
        EFFICIENTNET,  # 3
        GOOGLENET,     # 4
        MOBILENET,     # 5
        SQUEEZENET,    # 6
    ) = range(7)
class LayerTypes(Enum):
    """
    Kinds of layers, each valued by the torch.nn class that implements it.
    """

    CONVOLUTIONAL, LINEAR = nn.Conv2d, nn.Linear
class TransformTypes(Enum):
    """
    Image transform options, each valued by its human-readable label.
    """

    (
        PAD,
        JITTER,
        RANDOM_SCALE,
        RANDOM_ROTATE,
        AD_JITTER,
    ) = (
        "Pad",
        "Jitter",
        "Random Scale",
        "Random Rotate",
        "Additional Jitter",
    )
# Output captured by the most recently fired hook; None until a hook runs.
_hook_activations = None


def _get_activations():
    """
    Builds a forward hook that records a layer's output.

    The returned callable takes (module, inputs, output), the argument order
    used by register_forward_hook, and stores output in the module-level
    _hook_activations variable, replacing whatever was stored before.
    :return: The hook callable
    """
    def _store_output(_module, _inputs, output):
        # Returns None implicitly, so the layer's output is left unchanged.
        global _hook_activations
        _hook_activations = output

    return _store_output
def _get_activation_shape():
    """
    Reports the size of the activation currently held in _hook_activations.

    All dimensions of length one are squeezed out before the size is taken.
    Assumes a hook has already stored a tensor; with nothing captured yet the
    attribute access on None fails.
    :return: The size of the squeezed activation
    """
    squeezed = _hook_activations.squeeze()
    return squeezed.shape
def setup_model(model): | |
""" | |
Takes in a model type and creates a standard and robust model. Currently | |
only setups up for the CIFAR10 dataset. Loads the respective standard and | |
robust checkpoints saved in the models directory. Raises a ValueError if | |
model type is not valid. | |
:param model: Expects a enum of type ModelTypes to specified model to setup | |
:return: model | |
""" | |
curr_dir = os.path.dirname(__file__) + "/../" | |
match model: | |
case ModelTypes.RESNET: | |
base = models.resnet18( | |
weights=models.ResNet18_Weights.IMAGENET1K_V1) | |
base.fc = nn.Linear(base.fc.in_features, 10) | |
base.load_state_dict( | |
load(curr_dir + "models/resnet_standard_cifar10.pt")) | |
model = base | |
case ModelTypes.ALEXNET: | |
base = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1) | |
base.classifier[6] = nn.Linear(base.classifier[6].in_features, 10) | |
base.load_state_dict( | |
load(curr_dir + "models/alexnet_standard_cifar10.pt")) | |
model = base | |
case ModelTypes.DENSENET: | |
base = models.densenet121( | |
weights=models.DenseNet121_Weights.IMAGENET1K_V1) | |
base.classifier = nn.Linear(base.classifier.in_features, 10) | |
base.load_state_dict( | |
load(curr_dir + "models/densenet_standard_cifar10.pt")) | |
model = base | |
case ModelTypes.EFFICIENTNET: | |
base = hub.load('NVIDIA/DeepLearningExamples:torchhub', | |
'nvidia_efficientnet_b0', pretrained=True) | |
base.classifier.fc = nn.Linear(base.classifier.fc.in_features, 10) | |
base.load_state_dict( | |
load(curr_dir + "models/efficientnet_standard_cifar10.pt")) | |
model = base | |
case ModelTypes.GOOGLENET: | |
base = models.googlenet( | |
weights=models.GoogLeNet_Weights.IMAGENET1K_V1) | |
base.fc = nn.Linear(base.fc.in_features, 10) | |
base.load_state_dict( | |
load(curr_dir + "models/googlenet_standard_cifar10.pt")) | |
model = base | |
case ModelTypes.MOBILENET: | |
base = models.mobilenet_v2( | |
weights=models.MobileNet_V2_Weights.IMAGENET1K_V1) | |
base.classifier[1] = nn.Linear(base.classifier[1].in_features, 10) | |
base.load_state_dict( | |
load(curr_dir + "models/mobilenet_standard_cifar10.pt")) | |
model = base | |
case ModelTypes.SQUEEZENET: | |
base = models.squeezenet1_0( | |
weights=models.SqueezeNet1_0_Weights.IMAGENET1K_V1) | |
base.classifier[1] = nn.Conv2d( | |
512, 10, kernel_size=(1, 1), stride=(1, 1)) | |
base.load_state_dict( | |
load(curr_dir + "models/squeezenet_standard_cifar10.pt")) | |
model = base | |
case _: | |
raise ValueError("Unknown model choice") | |
return model | |
def get_layer_by_name(model, layer_name):
    """
    Gets a layer given a name and model.

    The name is split on underscores and each part is looked up as an
    attribute of the layer found so far, starting from the model itself
    (for example "layer1_0_conv1" resolves model.layer1, then its child "0",
    then that child's conv1).
    :param model: Model to look for layer in
    :param layer_name: Underscore separated path of the layer to search
    :return: Layer found, always an nn.Module
    :raises ValueError: If any part of the path is missing or does not
        resolve to an nn.Module
    """
    not_found = f"Layer '{layer_name}' not found in the model."

    current_layer = model
    for name in layer_name.split('_'):
        # Sub layers can only be looked up on modules
        if not isinstance(current_layer, nn.Module):
            raise ValueError(not_found)
        current_layer = getattr(current_layer, name, None)
        if current_layer is None:
            raise ValueError(not_found)

    # getattr also resolves methods, flags and parameters (e.g. "forward",
    # "training", "weight"); those are not layers, so reject them too.
    if not isinstance(current_layer, nn.Module):
        raise ValueError(not_found)
    return current_layer
def get_feature_map_sizes(model, layers, img=None):
    """
    Gets the feature map sizes, used to dynamically limit values on the
    interface.

    A forward hook is registered on every nn.Conv2d in layers and the image
    is passed through the model once; each hook records the size of its
    layer's output with all length-one dimensions squeezed out. The model is
    switched to eval mode, and moved to the GPU when one is available.
    :param model: Model to pass image through
    :param layers: Layers to grab feature map sizes from
    :param img: Image to use for forward pass, without a batch dimension.
        When None a random 227 by 227 image is created
    :return: List aligned with layers holding the feature map size for each
        nn.Conv2d, and None for any other layer or for a layer the forward
        pass never reached
    """
    if img is None:
        # TODO Remove this and just generate a blank image of 227 by 227
        img = h_manipulation.create_random_image(
            (227, 227),
            h_manipulation.DatasetNormalizations.CIFAR10_MEAN.value,
            h_manipulation.DatasetNormalizations.CIFAR10_STD.value,
        ).clone().unsqueeze(0)
    else:
        img = img.unsqueeze(0)

    # Use GPU if possible
    if torch.cuda.is_available():
        print("Using GPU for generation")
        model = model.cuda()
        img = img.cuda()
    else:
        print("Using CPU for generation")
    model = model.eval()

    feature_map_sizes = [None] * len(layers)
    handles = []

    def _size_recorder(position):
        # A factory binds position per hook, avoiding late-binding closures
        def _record(_module, _inputs, output):
            feature_map_sizes[position] = output.squeeze().size()
        return _record

    try:
        for position, layer in enumerate(layers):
            if isinstance(layer, nn.Conv2d):
                handles.append(
                    layer.register_forward_hook(_size_recorder(position)))

        # One fake forward pass serves every hooked layer; gradients are not
        # needed just to read output sizes.
        if handles:
            with torch.no_grad():
                model(img)
    finally:
        # Never leave hooks attached, even if the forward pass fails
        for handle in handles:
            handle.remove()

    return feature_map_sizes