Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
@@ -19,13 +19,13 @@ def _inner(feat_net, hooks, x):
|
|
19 |
|
20 |
def _get_layers(arch:str, pretrained=True):
|
21 |
"Get the layers and arch for a VGG Model (16 and 19 are supported only)"
|
22 |
-
feat_net = vgg19(pretrained=pretrained)
|
23 |
config = _vgg_config.get(arch)
|
24 |
features = feat_net.features.cuda().eval()
|
25 |
for p in features.parameters(): p.requires_grad=False
|
26 |
return feat_net, [features[i] for i in config]
|
27 |
|
28 |
-
|
29 |
_vgg_config = {
|
30 |
'vgg16' : [1, 11, 18, 25, 20],
|
31 |
'vgg19' : [1, 6, 11, 20, 29, 22]
|
|
|
19 |
|
20 |
def _get_layers(arch:str, pretrained=True):
|
21 |
"Get the layers and arch for a VGG Model (16 and 19 are supported only)"
|
22 |
+
feat_net = vgg19(pretrained=pretrained) if arch.find('9') > 1 else vgg16(pretrained=pretrained)
|
23 |
config = _vgg_config.get(arch)
|
24 |
features = feat_net.features.cuda().eval()
|
25 |
for p in features.parameters(): p.requires_grad=False
|
26 |
return feat_net, [features[i] for i in config]
|
27 |
|
28 |
+
|
29 |
_vgg_config = {
|
30 |
'vgg16' : [1, 11, 18, 25, 20],
|
31 |
'vgg19' : [1, 6, 11, 20, 29, 22]
|