import torch, torchvision, os, collections

from . import parallelfolder, zdataset, renormalize, encoder_net, segmenter
from . import bargraph

def load_proggan(domain):
    '''Loads a pretrained progressive GAN generator for the given domain.'''
    from . import proggan
    weights_filename = dict(
        bedroom='proggan_bedroom-d8a89ff1.pth',
        church='proggan_churchoutdoor-7e701dd5.pth',
        conferenceroom='proggan_conferenceroom-21e85882.pth',
        diningroom='proggan_diningroom-3aa0ab80.pth',
        kitchen='proggan_kitchen-67f1e16c.pth',
        livingroom='proggan_livingroom-5ef336dd.pth',
        restaurant='proggan_restaurant-b8578299.pth',
        celebhq='proggan_celebhq-620d161c.pth')[domain]
    url = 'http://gandissect.csail.mit.edu/models/' + weights_filename
    try:
        sd = torch.hub.load_state_dict_from_url(url)  # pytorch >= 1.1
    except AttributeError:
        sd = torch.hub.model_zoo.load_url(url)  # fallback for older pytorch
    model = proggan.from_state_dict(sd)
    return model
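
# A usage sketch (not part of the original module): sampling images from a
# loaded generator, reusing the zdataset helper imported above.  Assumes
# network access to the gandissect.csail.mit.edu URLs.
#
#     generator = load_proggan('church')
#     z = zdataset.z_sample_for_model(generator, size=8)
#     images = generator(z)   # images nominally normalized to [-1, 1]
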
def load_vgg16(domain='places'):
    '''Loads a pretrained VGG-16 Places365 classifier with named layers.'''
    assert domain == 'places'
    model = torchvision.models.vgg16(num_classes=365)
    # Rename the feature and classifier layers with the standard VGG names
    # (conv1_1 ... pool5, fc6 ... fc8a) so they can be addressed by name.
    model.features = torch.nn.Sequential(collections.OrderedDict(zip([
        'conv1_1', 'relu1_1',
        'conv1_2', 'relu1_2',
        'pool1',
        'conv2_1', 'relu2_1',
        'conv2_2', 'relu2_2',
        'pool2',
        'conv3_1', 'relu3_1',
        'conv3_2', 'relu3_2',
        'conv3_3', 'relu3_3',
        'pool3',
        'conv4_1', 'relu4_1',
        'conv4_2', 'relu4_2',
        'conv4_3', 'relu4_3',
        'pool4',
        'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2',
        'conv5_3', 'relu5_3',
        'pool5'],
        model.features)))
    model.classifier = torch.nn.Sequential(collections.OrderedDict(zip([
        'fc6', 'relu6',
        'drop6',
        'fc7', 'relu7',
        'drop7',
        'fc8a'],
        model.classifier)))
    baseurl = 'http://gandissect.csail.mit.edu/models/'
    url = baseurl + 'vgg16_places365-6e38b568.pth'
    try:
        sd = torch.hub.load_state_dict_from_url(url)  # pytorch >= 1.1
    except AttributeError:
        sd = torch.hub.model_zoo.load_url(url)  # fallback for older pytorch
    model.load_state_dict(sd)
    model.eval()
    return model
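
# A usage sketch (hedged, not part of the original module): scoring one
# dataset image with the classifier.  Indexing follows load_test_image
# below, where ds[i][0] is the image tensor.
#
#     classifier = load_vgg16()
#     ds = load_dataset('places')
#     img = ds[0][0]
#     logits = classifier(img[None])   # scores for the 365 Places365 classes
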
def load_proggan_ablation(modelname):
    '''Loads one of the progressive GAN ablation-study generators.'''
    from . import proggan_ablation
    model_class, weights_filename = {
        "equalized-learning-rate": (proggan_ablation.G128_equallr,
            "equalized-learning-rate-88ed833d.pth"),
        "minibatch-discrimination": (proggan_ablation.G128_minibatch_disc,
            "minibatch-discrimination-604c5731.pth"),
        "minibatch-stddev": (proggan_ablation.G128_minibatch_disc,
            "minibatch-stddev-068bc667.pth"),
        "pixelwise-normalization": (proggan_ablation.G128_pixelwisenorm,
            "pixelwise-normalization-4da7e9ce.pth"),
        "progressive-training": (proggan_ablation.G128_simple,
            "progressive-training-70bd90ac.pth"),
        "small-minibatch": (proggan_ablation.G128_simple,
            "small-minibatch-04143d18.pth"),
        "wgangp": (proggan_ablation.G128_simple,
            "wgangp-beaa509a.pth")
    }[modelname]
    url = 'http://gandissect.csail.mit.edu/models/ablations/' + weights_filename
    try:
        sd = torch.hub.load_state_dict_from_url(url)  # pytorch >= 1.1
    except AttributeError:
        sd = torch.hub.model_zoo.load_url(url)  # fallback for older pytorch
    model = model_class()
    model.load_state_dict(sd)
    return model
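
# A usage sketch (hedged): the G128_* class names suggest these ablation
# generators produce 128x128 images; sampling follows the same pattern as
# load_proggan, assuming zdataset.z_sample_for_model can size latents for
# these architectures.
#
#     generator = load_proggan_ablation('wgangp')
#     z = zdataset.z_sample_for_model(generator, size=4)
#     images = generator(z)
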
def load_proggan_inversion(modelname):
    '''Loads an encoder network that inverts a progressive GAN generator.'''
    model_class, weights_filename = {
        "church": (encoder_net.HybridLayerNormEncoder,
            "church_invert_hybrid_cse-43e52428.pth"),
        "bedroom": (encoder_net.HybridLayerNormEncoder,
            "bedroom_invert_hybrid_cse-b943528e.pth"),
    }[modelname]
    url = 'http://gandissect.csail.mit.edu/models/encoders/' + weights_filename
    try:
        sd = torch.hub.load_state_dict_from_url(url)  # pytorch >= 1.1
    except AttributeError:
        sd = torch.hub.model_zoo.load_url(url)  # fallback for older pytorch
    # Unwrap checkpoints saved as {'state_dict': ...} and strip the
    # 'model.' prefix from parameter names.
    if 'state_dict' in sd:
        sd = sd['state_dict']
    sd = {k.replace('model.', ''): v for k, v in sd.items()}
    model = model_class()
    model.load_state_dict(sd)
    model.eval()
    return model
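
# A usage sketch (hedged, not part of the original module): encoding a
# generated image back toward its latent.  Assumes the encoder accepts
# images at the generator's output resolution.
#
#     generator = load_proggan('church')
#     encoder = load_proggan_inversion('church')
#     img, z = load_test_image(0, 'gan', 'church')
#     z_estimate = encoder(img)
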
g_datasets = {}


def load_dataset(domain, split=None, full=False, download=True):
    if domain in g_datasets:
        return g_datasets[domain]
    if domain == 'places':
        if split is None:
            split = 'val'
        dirname = 'datasets/microimagenet'
        if download and not os.path.exists(dirname):
            os.makedirs('datasets', exist_ok=True)
            torchvision.datasets.utils.download_and_extract_archive(
                'http://gandissect.csail.mit.edu/datasets/' +
                'microimagenet.zip',
                'datasets')
        return parallelfolder.ParallelImageFolders([dirname],
                classification=True,
                shuffle=True,
                transform=g_places_transform)
    else:
        if split is None:
            split = 'train'
        dirname = os.path.join(
            'datasets', 'lsun' if full else 'minilsun', domain)
        dirname += '_' + split
        if download and not full and not os.path.exists('datasets/minilsun'):
            os.makedirs('datasets', exist_ok=True)
            torchvision.datasets.utils.download_and_extract_archive(
                'http://gandissect.csail.mit.edu/datasets/minilsun.zip',
                'datasets',
                md5='a67a898673a559db95601314b9b51cd5')
        return parallelfolder.ParallelImageFolders([dirname],
                shuffle=True,
                transform=g_transform)
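
# A usage sketch (hedged): batched iteration over a dataset.  Assumes each
# item is a 1-tuple holding one image tensor, matching the ds[i][0]
# indexing used by load_test_image below.
#
#     ds = load_dataset('church', split='train')
#     loader = torch.utils.data.DataLoader(ds, batch_size=16, num_workers=2)
#     for [batch] in loader:
#         pass  # batch has shape [16, 3, 256, 256], normalized to [-1, 1]
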
# GAN training images are normalized to [-1, 1]; the places classifier
# expects the standard ImageNet normalization instead.
g_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(256),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

g_places_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    renormalize.NORMALIZER['imagenet']])

def load_segmenter(segmenter_name='netpqc'):
    '''Loads the segmenter.  Letters in the name select components:
    'p' includes object parts, 'q' quadrant segmentation, 'x' textures,
    and 'c' colors.'''
    all_parts = ('p' in segmenter_name)
    quad_seg = ('q' in segmenter_name)
    textures = ('x' in segmenter_name)
    colors = ('c' in segmenter_name)

    segmodels = []
    segmodels.append(segmenter.UnifiedParsingSegmenter(segsizes=[256],
            all_parts=all_parts,
            segdiv=('quad' if quad_seg else None)))
    if textures:
        segmenter.ensure_segmenter_downloaded('datasets/segmodel', 'texture')
        segmodels.append(segmenter.SemanticSegmenter(
            segvocab="texture", segarch=("resnet18dilated", "ppm_deepsup")))
    if colors:
        segmenter.ensure_segmenter_downloaded('datasets/segmodel', 'color')
        segmodels.append(segmenter.SemanticSegmenter(
            segvocab="color", segarch=("resnet18dilated", "ppm_deepsup")))
    if len(segmodels) == 1:
        segmodel = segmodels[0]
    else:
        segmodel = segmenter.MergedSegmenter(segmodels)
    seglabels = [l for l, c in segmodel.get_label_and_category_names()[0]]
    segcatlabels = segmodel.get_label_and_category_names()[0]
    return segmodel, seglabels, segcatlabels
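
# A usage sketch (hedged): segmenting a batch of images.  Assumes the
# segmenter objects expose a segment_batch method as in the netdissect
# segmenter module.
#
#     segmodel, seglabels, segcatlabels = load_segmenter('netpqc')
#     img, z = load_test_image(0, 'gan', 'church')
#     seg = segmodel.segment_batch(img)   # per-pixel label indices
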
def graph_conceptcatlist(conceptcatlist, cats=None, print_nums=False, **kwargs):
    count = collections.defaultdict(int)
    catcount = collections.defaultdict(int)
    for c in conceptcatlist:
        count[c] += 1
    for c in count.keys():
        catcount[c[1]] += 1
    if cats is None:
        cats = ['object', 'part', 'material', 'texture', 'color']
    catorder = dict((c, i) for i, c in enumerate(cats))
    # Sort by category order, then by descending count within each category.
    sorted_labels = sorted(count.keys(),
        key=lambda x: (catorder[x[1]], -count[x]))
    if print_nums:
        tot_num = 0
        for k in sorted_labels:
            print(count[k])
            tot_num += count[k]
        print("Total concepts: {}".format(tot_num))
    return bargraph.make_svg_bargraph(
        [label for label, cat in sorted_labels],
        [count[k] for k in sorted_labels],
        [(c, catcount[c]) for c in cats], **kwargs)

def save_concept_graph(filename, conceptlist):
    # Note: relies on a graph_conceptlist helper defined elsewhere.
    svg = graph_conceptlist(conceptlist, file_header=True)
    with open(filename, 'w') as f:
        f.write(svg)


def save_conceptcat_graph(filename, conceptcatlist):
    svg = graph_conceptcatlist(conceptcatlist, barheight=80, file_header=True)
    with open(filename, 'w') as f:
        f.write(svg)
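
# A usage sketch (hedged, with made-up example data): writing a concept
# bar graph as an SVG file.
#
#     concepts = [('tree', 'object'), ('sky', 'object'), ('wood', 'material')]
#     save_conceptcat_graph('concepts.svg', concepts)
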
def load_test_image(imgnum, split, model, full=False):
    '''Returns (image, z) for the imgnum-th GAN sample of a model, or
    (image, None) for the imgnum-th image of a dataset split.'''
    if split == 'gan':
        with torch.no_grad():
            generator = load_proggan(model)
            z = zdataset.z_sample_for_model(
                generator, size=(imgnum + 1))[imgnum]
            z = z[None]
            return generator(z), z
    assert split in ['train', 'val']
    ds = load_dataset(model, split, full=full)
    return ds[imgnum][0][None], None
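
# A minimal smoke test (a sketch, not part of the original module): it
# assumes the pretrained weights above are still downloadable.  Run it as
# a module (python -m ...) so the relative imports resolve.
if __name__ == '__main__':
    with torch.no_grad():
        generator = load_proggan('church')
        z = zdataset.z_sample_for_model(generator, size=1)
        img = generator(z)
    print('generated image batch of shape', tuple(img.shape))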