File size: 6,378 Bytes
0e3bdb5
 
e8bb026
 
da1c92f
2d249b4
0e3bdb5
da1c92f
0e3bdb5
e8bb026
 
 
0e3bdb5
e8bb026
 
 
 
 
 
 
 
2d249b4
4c157d1
 
 
 
 
 
 
 
2d249b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e3bdb5
2d249b4
 
 
 
 
 
0e3bdb5
 
 
 
2d249b4
 
 
 
0e3bdb5
 
2d249b4
0e3bdb5
 
11e63b5
0e3bdb5
 
2d249b4
 
0e3bdb5
 
2d249b4
11e63b5
2d249b4
0e3bdb5
 
 
 
2d249b4
11e63b5
2d249b4
0e3bdb5
2d249b4
 
0e3bdb5
 
2d249b4
11e63b5
2d249b4
0e3bdb5
2d249b4
 
0e3bdb5
 
2d249b4
11e63b5
2d249b4
0e3bdb5
2d249b4
 
0e3bdb5
 
2d249b4
11e63b5
2d249b4
0e3bdb5
2d249b4
 
0e3bdb5
 
2d249b4
11e63b5
2d249b4
0e3bdb5
2d249b4
 
 
 
 
 
11e63b5
2d249b4
0e3bdb5
2d249b4
 
0e3bdb5
 
2d249b4
0e3bdb5
2d249b4
 
 
 
 
0e3bdb5
0b69c82
2d249b4
 
0b69c82
 
 
 
 
 
 
 
 
 
 
0e3bdb5
da1c92f
2d249b4
 
 
 
 
 
 
 
0e3bdb5
da1c92f
4c157d1
da1c92f
e8bb026
 
da1c92f
 
0e3bdb5
 
 
 
 
 
 
 
 
 
 
2d249b4
 
0e3bdb5
0b69c82
2d249b4
0b69c82
2d249b4
 
0b69c82
0e3bdb5
2d249b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import torch
from torch import nn, load, hub
from torchvision import models
import helpers.manipulation as h_manipulation
from enum import Enum


class ModelTypes(Enum):
    """Architectures that ``setup_model`` can build, keyed by a small integer."""

    # Backbones handled by setup_model(); values are stable integer ids.
    RESNET = 0
    ALEXNET = 1
    DENSENET = 2
    EFFICIENTNET = 3
    GOOGLENET = 4
    MOBILENET = 5
    SQUEEZENET = 6


class LayerTypes(Enum):
    """Layer kinds; each member's value is the ``torch.nn`` class it stands for."""

    CONVOLUTIONAL = nn.Conv2d  # 2-D convolution layers
    LINEAR = nn.Linear  # fully connected layers

class TransformTypes(Enum):
    """Image transforms, each valued by a human-readable label."""

    PAD = "Pad"
    JITTER = "Jitter"
    RANDOM_SCALE = "Random Scale"
    RANDOM_ROTATE = "Random Rotate"
    # Second jitter step, distinct from JITTER above.
    AD_JITTER = "Additional Jitter"



# Most recent layer output captured by the hook returned from
# _get_activations(); read back by _get_activation_shape(). Stays None until
# a hooked forward pass has run.
_hook_activations = None


# Hook function
def _get_activations():
    """
    Build a forward hook that records a layer's output.

    Register the returned callable with ``register_forward_hook``. Every time
    the hooked layer runs, its output is stored in the module-level
    ``_hook_activations``, overwriting whatever was stored before.
    :return: The hook callable
    """
    def _record_output(_module, _inputs, result):
        global _hook_activations
        _hook_activations = result

    return _record_output


def _get_activation_shape():
    """
    Gets the shape of the activation most recently captured by a hook built
    with ``_get_activations``.

    Only the leading batch dimension is dropped (when it has size 1). A bare
    ``squeeze()`` would also remove singleton channel or spatial dimensions,
    e.g. a (1, 512, 1, 1) map would wrongly become (512,).
    :return: A torch.Size of the activation without its batch dimension
    :raises RuntimeError: If no activation has been captured yet
    """
    if _hook_activations is None:
        raise RuntimeError(
            "No activation captured yet; register the hook from "
            "_get_activations() and run a forward pass first.")
    return _hook_activations.squeeze(0).size()


def setup_model(model):
    """
    Takes in a model type and creates a standard and robust model. Currently 
    only setups up for the CIFAR10 dataset. Loads the respective standard and 
    robust checkpoints saved in the models directory. Raises a ValueError if 
    model type is not valid.

    :param model: Expects a enum of type ModelTypes to specified model to setup
    :return: model
    """
    curr_dir = os.path.dirname(__file__) + "/../"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    match model:
        case ModelTypes.RESNET:
            base = models.resnet18(
                weights=models.ResNet18_Weights.IMAGENET1K_V1)
            base.fc = nn.Linear(base.fc.in_features, 10)

            base.load_state_dict(
                load(curr_dir + "models/resnet_standard_cifar10.pt", map_location=device))
            model = base
        case ModelTypes.ALEXNET:
            base = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
            base.classifier[6] = nn.Linear(base.classifier[6].in_features, 10)

            base.load_state_dict(
                load(curr_dir + "models/alexnet_standard_cifar10.pt", map_location=device))
            model = base
        case ModelTypes.DENSENET:
            base = models.densenet121(
                weights=models.DenseNet121_Weights.IMAGENET1K_V1)
            base.classifier = nn.Linear(base.classifier.in_features, 10)

            base.load_state_dict(
                load(curr_dir + "models/densenet_standard_cifar10.pt", map_location=device))
            model = base
        case ModelTypes.EFFICIENTNET:
            base = hub.load('NVIDIA/DeepLearningExamples:torchhub',
                            'nvidia_efficientnet_b0', pretrained=True)
            base.classifier.fc = nn.Linear(base.classifier.fc.in_features, 10)

            base.load_state_dict(
                load(curr_dir + "models/efficientnet_standard_cifar10.pt", map_location=device))
            model = base
        case ModelTypes.GOOGLENET:
            base = models.googlenet(
                weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
            base.fc = nn.Linear(base.fc.in_features, 10)

            base.load_state_dict(
                load(curr_dir + "models/googlenet_standard_cifar10.pt", map_location=device))
            model = base
        case ModelTypes.MOBILENET:
            base = models.mobilenet_v2(
                weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
            base.classifier[1] = nn.Linear(base.classifier[1].in_features, 10)

            base.load_state_dict(
                load(curr_dir + "models/mobilenet_standard_cifar10.pt", map_location=device))
            model = base
        case ModelTypes.SQUEEZENET:
            base = models.squeezenet1_0(
                weights=models.SqueezeNet1_0_Weights.IMAGENET1K_V1)
            base.classifier[1] = nn.Conv2d(
                512, 10, kernel_size=(1, 1), stride=(1, 1))

            base.load_state_dict(
                load(curr_dir + "models/squeezenet_standard_cifar10.pt", map_location=device))
            model = base
        case _:
            raise ValueError("Unknown model choice")
    return model


def get_layer_by_name(model, layer_name):
    """
    Gets a (possibly nested) layer of a model from an underscore-separated
    path.

    Each '_'-separated part is looked up as an attribute of the module
    resolved so far. For example, "layer1_0_conv1" resolves as
    getattr(getattr(getattr(model, "layer1"), "0"), "conv1"). Every step must
    yield an nn.Module. Missing attributes and non-module attributes
    (parameters, flags, methods) are rejected.
    NOTE: because '_' is the separator, a submodule whose own name contains
    an underscore cannot be addressed.

    :param model: Model to look for layer in
    :param layer_name: Underscore-separated layer path to search
    :return: Layer found (an nn.Module)
    :raises ValueError: If the path does not resolve to a submodule
    """
    if not isinstance(model, nn.Module):
        raise ValueError(f"Layer '{layer_name}' not found in the model.")

    current_layer = model
    for name in layer_name.split('_'):  # Split into sub layers
        current_layer = getattr(current_layer, name, None)
        # Only submodules count as layers. getattr would also happily return
        # a Parameter, a bool flag such as `training`, or a bound method.
        if not isinstance(current_layer, nn.Module):
            raise ValueError(f"Layer '{layer_name}' not found in the model.")

    return current_layer


def get_feature_map_sizes(model, layers, img=None):
    """
    Gets the feature map sizes, used to dynamically limit values on the
    interface.

    All Conv2d layers in ``layers`` are hooked at once. A single gradient-free
    forward pass then records each one's output shape. The model is moved to
    the GPU when CUDA is available, and is switched to eval mode.
    NOTE(review): nn.Module.cuda()/eval() presumably act on the caller's model
    in place. Confirm callers expect that.

    :param img: Image to use for forward pass, without a batch dimension
        (one is added here). Presumably (C, H, W). If None, a random
        227x227 CIFAR10-normalised image is generated.
    :param model: Model to pass image through
    :param layers: Sized sequence of layers to grab feature map sizes from
    :return: List aligned with ``layers``. Each entry is the torch.Size of
        that Conv2d layer's output with only the batch dimension removed.
        Non-Conv2d layers, and conv layers not reached by the forward pass,
        give None.
    """
    feature_map_sizes = [None] * len(layers)
    if img is None:
        # TODO Remove this and just generates a blank image of 227 by 227
        img = h_manipulation.create_random_image(
            (227, 227),
            h_manipulation.DatasetNormalizations.CIFAR10_MEAN.value,
            h_manipulation.DatasetNormalizations.CIFAR10_STD.value,
        ).clone().unsqueeze(0)
    else:
        img = img.unsqueeze(0)

    # Use GPU if possible
    if torch.cuda.is_available():
        print("Using GPU for generation")
        model = model.cuda()
        img = img.cuda()
    else:
        print("Using CPU for generation")
    model = model.eval()

    def _make_hook(slot):
        # Factory binds `slot` per layer (avoids the late-binding closure trap).
        def hook(_module, _inputs, output):
            # Drop only the batch dim; squeeze() would also collapse
            # singleton channel/spatial dims such as a 1x1 feature map.
            feature_map_sizes[slot] = output.squeeze(0).size()
        return hook

    handles = [
        layer.register_forward_hook(_make_hook(slot))
        for slot, layer in enumerate(layers)
        if isinstance(layer, nn.Conv2d)
    ]
    if not handles:
        # Nothing to measure: skip the forward pass entirely.
        return feature_map_sizes

    try:
        # Only shapes are needed, so don't build an autograd graph.
        with torch.no_grad():
            model(img)
    finally:
        # Always detach the hooks, even if the forward pass raises.
        for handle in handles:
            handle.remove()
    return feature_map_sizes