Spaces:
Runtime error
Runtime error
File size: 1,616 Bytes
cbdb39b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import torch
from os import path
from sys import version_info
from collections import OrderedDict
from torch.utils.model_zoo import load_url
# Download the pretrained VGG-19, VGG-16 and NIN weights into the local
# "models" directory. The two VGG state dicts have four classifier keys
# renamed before being saved; the NIN file is stored exactly as downloaded.
from os import makedirs

_MODEL_DIR = "models"

# Old key -> new key for the VGG classifier layers ("fix the layer names").
# Keys that are not listed here are kept unchanged.
_CLASSIFIER_KEY_RENAMES = {
    'classifier.1.weight': u'classifier.0.weight',
    'classifier.1.bias': u'classifier.0.bias',
    'classifier.4.weight': u'classifier.3.weight',
    'classifier.4.bias': u'classifier.3.bias',
}


def _fix_layer_names(state_dict):
    """Return a new OrderedDict with the classifier keys renamed.

    Entry order and values are preserved; only the keys listed in
    ``_CLASSIFIER_KEY_RENAMES`` are replaced.
    """
    return OrderedDict(
        (_CLASSIFIER_KEY_RENAMES.get(key, key), value)
        for key, value in state_dict.items()
    )


def _download_vgg(url, filename):
    """Fetch a VGG state dict from *url*, fix its keys, save it as *filename*."""
    state_dict = load_url(url)
    torch.save(_fix_layer_names(state_dict), path.join(_MODEL_DIR, filename))


def _download_file(url, filename):
    """Download *url* unmodified into the models directory as *filename*."""
    target = path.join(_MODEL_DIR, filename)
    # urlretrieve lives in a different place on Python 2 and Python 3.
    if version_info[0] < 3:
        import urllib
        urllib.URLopener().retrieve(url, target)
    else:
        import urllib.request
        urllib.request.urlretrieve(url, target)


# Saving fails if the target directory is missing, so create it up front
# rather than failing after a large download has already completed.
if not path.isdir(_MODEL_DIR):
    makedirs(_MODEL_DIR)

# Download the VGG-19 model and fix the layer names
print("Downloading the VGG-19 model")
_download_vgg(
    "https://web.eecs.umich.edu/~justincj/models/vgg19-d01eb7cb.pth",
    "vgg19-d01eb7cb.pth",
)

# Download the VGG-16 model and fix the layer names
print("Downloading the VGG-16 model")
_download_vgg(
    "https://web.eecs.umich.edu/~justincj/models/vgg16-00b39a1b.pth",
    "vgg16-00b39a1b.pth",
)

# Download the NIN model
print("Downloading the NIN model")
_download_file(
    "https://raw.githubusercontent.com/ProGamerGov/pytorch-nin/master/nin_imagenet.pth",
    "nin_imagenet.pth",
)

print("All models have been successfully downloaded")
|