eval
Browse files- evaluation/README.md +16 -0
- evaluation/__init__.py +0 -0
- evaluation/check_fc8_labels.py +61 -0
- evaluation/download_evaluation_data.py +39 -0
- evaluation/eval_deception_score.py +196 -0
- evaluation/logger.py +80 -0
- evaluation/run_deception_score_vgg_16_wikiart.sh +39 -0
evaluation/README.md
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Style transfer quaintitative evaluation using Deception Score
|
2 |
+
|
3 |
+
### How to calculate Deception Score:
|
4 |
+
|
5 |
+
1. Run `./download_evaluation_data.py` to download the weights for artist classification model.
|
6 |
+
2. Set `results_dir` variable in `eval_deception_score.py:92` to point to the directory with stylized images.
|
7 |
+
All images generated by one method must be in one directory.
|
8 |
+
Image filenames must be in format `"{content_name}_stylized_{artist_name}.jpg"`, for example: `"Places366_val_00000510_stylized_vincent-van-gogh.jpg"`.
|
9 |
+
3. Run `./run_deception_score_vgg_16_wikiart.sh`
|
10 |
+
4. Read results in the log file in `./logs` directory.
|
11 |
+
|
12 |
+
|
13 |
+
### How to evaluate your own model:
|
14 |
+
|
15 |
+
- Download validation sets from MSCOCO ([val2017.zip](http://images.cocodataset.org/zips/val2017.zip)) and Places365 ([val_large.tar](http://data.csail.mit.edu/places/places365/val_large.tar)).
|
16 |
+
- To compare with deception score reported in the paper run your stylization model on the content images listed in [eval_paths_700_val.json](evaluation_data/eval_paths_700_val.json).
|
evaluation/__init__.py
ADDED
File without changes
|
evaluation/check_fc8_labels.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2018 Artsiom Sanakoyeu and Dmytro Kotovenko
|
2 |
+
#
|
3 |
+
# This file is part of Adaptive Style Transfer
|
4 |
+
#
|
5 |
+
# Adaptive Style Transfer is free software: you can redistribute it and/or modify
|
6 |
+
# it under the terms of the GNU General Public License as published by
|
7 |
+
# the Free Software Foundation, either version 3 of the License, or
|
8 |
+
# (at your option) any later version.
|
9 |
+
#
|
10 |
+
# Adaptive Style Transfer is distributed in the hope that it will be useful,
|
11 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13 |
+
# GNU General Public License for more details.
|
14 |
+
#
|
15 |
+
# You should have received a copy of the GNU General Public License
|
16 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
17 |
+
|
18 |
+
import pandas as pd
|
19 |
+
import h5py
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
ARTISTS = ['claude-monet',
|
23 |
+
'paul-cezanne',
|
24 |
+
'el-greco',
|
25 |
+
'paul-gauguin',
|
26 |
+
'samuel-peploe',
|
27 |
+
'vincent-van-gogh',
|
28 |
+
'edvard-munch',
|
29 |
+
'pablo-picasso',
|
30 |
+
'berthe-morisot',
|
31 |
+
'ernst-ludwig-kirchner',
|
32 |
+
'jackson-pollock',
|
33 |
+
'wassily-kandinsky',
|
34 |
+
'nicholas-roerich']
|
35 |
+
|
36 |
+
|
37 |
+
def get_artist_labels_wikiart(artists=ARTISTS):
|
38 |
+
"""
|
39 |
+
Get mapping of artist name to class label
|
40 |
+
"""
|
41 |
+
split_df = pd.read_hdf('evaluation_data/split.hdf5')
|
42 |
+
|
43 |
+
labels = dict()
|
44 |
+
|
45 |
+
for artist_id in artists:
|
46 |
+
artist_id_in_split = artist_id
|
47 |
+
print artist_id
|
48 |
+
cur_df = split_df[split_df.index.str.startswith(artist_id_in_split)]
|
49 |
+
assert len(cur_df)
|
50 |
+
if not np.all(cur_df.index.str.startswith(artist_id_in_split + '_')):
|
51 |
+
print cur_df[~cur_df.index.str.startswith(artist_id_in_split + '_')]
|
52 |
+
assert False
|
53 |
+
|
54 |
+
print '===='
|
55 |
+
labels[artist_id] = cur_df['label'][0]
|
56 |
+
return labels
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == '__main__':
|
60 |
+
|
61 |
+
print get_artist_labels_wikiart(ARTISTS)
|
evaluation/download_evaluation_data.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
from __future__ import print_function
|
3 |
+
|
4 |
+
import requests
|
5 |
+
import os
|
6 |
+
|
7 |
+
from torchvision.datasets.utils import download_url
|
8 |
+
|
9 |
+
API_ENDPOINT = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?public_key={}'
|
10 |
+
|
11 |
+
EVALUATION_DATA_URL = 'https://yadi.sk/d/A2CBqSGuJ0M_XA'
|
12 |
+
|
13 |
+
|
14 |
+
def get_real_direct_link(sharing_link):
|
15 |
+
pk_request = requests.get(API_ENDPOINT.format(sharing_link))
|
16 |
+
|
17 |
+
return pk_request.json()['href']
|
18 |
+
|
19 |
+
|
20 |
+
def unzip(path, target_dir='.'):
|
21 |
+
import zipfile
|
22 |
+
with zipfile.ZipFile(path, 'r') as zip_ref:
|
23 |
+
zip_ref.extractall(target_dir)
|
24 |
+
|
25 |
+
|
26 |
+
def main():
|
27 |
+
root = "."
|
28 |
+
link = get_real_direct_link(EVALUATION_DATA_URL)
|
29 |
+
filename = 'evaluation_data.zip'
|
30 |
+
print('Downloadng data (1Gb). This may take a while...')
|
31 |
+
download_url(link, root, filename, None)
|
32 |
+
print('Unzipping...')
|
33 |
+
unzip(os.path.join(root, filename), target_dir='.')
|
34 |
+
print('Done.')
|
35 |
+
|
36 |
+
|
37 |
+
if __name__ == '__main__':
|
38 |
+
main()
|
39 |
+
|
evaluation/eval_deception_score.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2018 Artsiom Sanakoyeu and Dmytro Kotovenko
|
2 |
+
#
|
3 |
+
# This file is part of Adaptive Style Transfer
|
4 |
+
#
|
5 |
+
# Adaptive Style Transfer is free software: you can redistribute it and/or modify
|
6 |
+
# it under the terms of the GNU General Public License as published by
|
7 |
+
# the Free Software Foundation, either version 3 of the License, or
|
8 |
+
# (at your option) any later version.
|
9 |
+
#
|
10 |
+
# Adaptive Style Transfer is distributed in the hope that it will be useful,
|
11 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13 |
+
# GNU General Public License for more details.
|
14 |
+
#
|
15 |
+
# You should have received a copy of the GNU General Public License
|
16 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
import os
|
20 |
+
import sys
|
21 |
+
from pprint import pformat
|
22 |
+
import glob
|
23 |
+
import numpy as np
|
24 |
+
import pandas as pd
|
25 |
+
import re
|
26 |
+
|
27 |
+
from feature_extractor.feature_extractor import SlimFeatureExtractor
|
28 |
+
from logger import Logger
|
29 |
+
from check_fc8_labels import get_artist_labels_wikiart
|
30 |
+
|
31 |
+
|
32 |
+
def parse_one_or_list(str_value):
|
33 |
+
if str_value is not None:
|
34 |
+
if str_value.lower() == 'none':
|
35 |
+
str_value = None
|
36 |
+
elif ',' in str_value:
|
37 |
+
str_value = str_value.split(',')
|
38 |
+
return str_value
|
39 |
+
|
40 |
+
|
41 |
+
def parse_list(str_value):
|
42 |
+
if ',' in str_value:
|
43 |
+
str_value = str_value.split(',')
|
44 |
+
else:
|
45 |
+
str_value = [str_value]
|
46 |
+
return str_value
|
47 |
+
|
48 |
+
|
49 |
+
def parse_none(str_value):
|
50 |
+
if str_value is not None:
|
51 |
+
if str_value.lower() == 'none' or str_value == "":
|
52 |
+
str_value = None
|
53 |
+
return str_value
|
54 |
+
|
55 |
+
|
56 |
+
def parse_args(argv):
|
57 |
+
parser = argparse.ArgumentParser()
|
58 |
+
parser.add_argument('-net', '--net', help='network type',
|
59 |
+
choices=['vgg_16', 'vgg_16_multihead'], default='vgg_16')
|
60 |
+
parser.add_argument('-log', '--log-path', help='log path', type=str,
|
61 |
+
default='/tmp/res.txt'
|
62 |
+
)
|
63 |
+
parser.add_argument('-s', '--snapshot_path', type=str,
|
64 |
+
default='vgg_16.ckpt')
|
65 |
+
parser.add_argument('-b', '--batch-size', type=int, default=64)
|
66 |
+
parser.add_argument('--method', type=str, default='ours')
|
67 |
+
parser.add_argument('--num_classes', type=int, default=624)
|
68 |
+
parser.add_argument('--dataset', type=str, default='wikiart', choices=['wikiart'])
|
69 |
+
args = parser.parse_args(argv)
|
70 |
+
args = vars(args)
|
71 |
+
return args
|
72 |
+
|
73 |
+
|
74 |
+
def create_slim_extractor(cli_params):
|
75 |
+
extractor_class = SlimFeatureExtractor
|
76 |
+
extractor_ = extractor_class(cli_params['net'], cli_params['snapshot_path'],
|
77 |
+
should_restore_classifier=True,
|
78 |
+
gpu_memory_fraction=0.95,
|
79 |
+
vgg_16_heads=None if cli_params['net'] != 'vgg_16_multihead' else {'artist_id': cli_params['num_classes']})
|
80 |
+
return extractor_
|
81 |
+
|
82 |
+
|
83 |
+
classification_layer = {
|
84 |
+
'vgg_16': 'vgg_16/fc8',
|
85 |
+
'vgg_16_multihead': 'vgg_16/fc8_artist_id'
|
86 |
+
}
|
87 |
+
|
88 |
+
|
89 |
+
def run(extractor, classification_layer, images_df, batch_size=64, logger=Logger()):
|
90 |
+
images_df = images_df.copy()
|
91 |
+
if len(images_df) == 0:
|
92 |
+
print 'No images found!'
|
93 |
+
return -1, 0, 0
|
94 |
+
probs = extractor.extract(images_df['image_path'].values, [classification_layer],
|
95 |
+
verbose=1, batch_size=batch_size)
|
96 |
+
images_df['predicted_class'] = np.argmax(probs, axis=1).tolist()
|
97 |
+
is_correct = images_df['label'] == images_df['predicted_class']
|
98 |
+
accuracy = float(is_correct.sum()) / len(images_df)
|
99 |
+
|
100 |
+
logger.log('Num images: {}'.format(len(images_df)))
|
101 |
+
logger.log('Correctly classified: {}/{}'.format(is_correct.sum(), len(images_df)))
|
102 |
+
logger.log('Accuracy: {:.5f}'.format(accuracy))
|
103 |
+
logger.log('\n===')
|
104 |
+
return accuracy, is_correct.sum(), len(images_df)
|
105 |
+
|
106 |
+
|
107 |
+
# image filenames must be in format "{content_name}_stylized_{artist_name}.jpg"
|
108 |
+
# uncomment methods which you want to evaluate and set the paths to the folders with the stylized images
|
109 |
+
results_dir = {
|
110 |
+
'ours': 'path/to/our/stylizations',
|
111 |
+
# 'gatys': 'path/to/gatys_stylizations',
|
112 |
+
# 'cyclegan': '',
|
113 |
+
# 'adain': '',
|
114 |
+
# 'johnson': '',
|
115 |
+
# 'wct': '',
|
116 |
+
# 'real_wiki_test': os.path.expanduser('~/workspace/wikiart/images_square_227x227') # uncomment to test on real images from wikiart test set
|
117 |
+
}
|
118 |
+
|
119 |
+
|
120 |
+
style_2_image_name = {u'berthe-morisot': u'Morisot-1886-the-lesson-in-the-garden',
|
121 |
+
u'claude-monet': u'monet-1914-water-lilies-37.jpg!HD',
|
122 |
+
u'edvard-munch': u'Munch-the-scream-1893',
|
123 |
+
u'el-greco': u'el-greco-the-resurrection-1595.jpg!HD',
|
124 |
+
u'ernst-ludwig-kirchner': u'Kirchner-1913-street-berlin.jpg!HD',
|
125 |
+
u'jackson-pollock': u'Pollock-number-one-moma-November-31-1950-1950',
|
126 |
+
u'nicholas-roerich': u'nicholas-roerich_mongolia-campaign-of-genghis-khan',
|
127 |
+
u'pablo-picasso': u'weeping-woman-1937',
|
128 |
+
u'paul-cezanne': u'still-life-with-apples-1894.jpg!HD',
|
129 |
+
u'paul-gauguin': u'Gauguin-the-seed-of-the-areoi-1892',
|
130 |
+
u'samuel-peploe': u'peploe-ile-de-brehat-1911-1',
|
131 |
+
u'vincent-van-gogh': u'vincent-van-gogh_road-with-cypresses-1890',
|
132 |
+
u'wassily-kandinsky': u'Kandinsky-improvisation-28-second-version-1912'}
|
133 |
+
|
134 |
+
|
135 |
+
artist_2_label_wikiart = get_artist_labels_wikiart()
|
136 |
+
|
137 |
+
|
138 |
+
def get_images_df(dataset, method, artist_slug):
|
139 |
+
images_dir = results_dir[method]
|
140 |
+
paths = glob.glob(os.path.join(images_dir, '*.jpg')) + glob.glob(os.path.join(images_dir, '*.png'))
|
141 |
+
# print paths
|
142 |
+
assert len(paths) or method.startswith('real')
|
143 |
+
|
144 |
+
if not method.startswith('real'):
|
145 |
+
cur_style_paths = [x for x in paths if re.match('.*_stylized_({}|{}).(jpg|png)'.format(artist_slug, style_2_image_name[artist_slug]), os.path.basename(x)) is not None]
|
146 |
+
elif method == 'real_wiki_test':
|
147 |
+
# use only images from the test set
|
148 |
+
split_df = pd.read_hdf(os.path.expanduser('evaluation_data/split.hdf5'))
|
149 |
+
split_df['image_id'] = split_df.index
|
150 |
+
df = split_df[split_df['split'] == 'test']
|
151 |
+
df['artist_id'] = df['image_id'].apply(lambda x: x.split('_', 1)[0])
|
152 |
+
df['image_path'] = df['image_id'].apply(lambda x: os.path.join(results_dir['real_wiki_test'], x + '.png'))
|
153 |
+
cur_style_paths = df.loc[df['artist_id'] == artist_slug, 'image_path'].values
|
154 |
+
|
155 |
+
df = pd.DataFrame(index=[os.path.basename(x).split('_stylized_', 1)[0].rstrip('.') for x in
|
156 |
+
cur_style_paths], data={'image_path': cur_style_paths, 'artist': artist_slug})
|
157 |
+
|
158 |
+
df['label'] = artist_2_label_wikiart[artist_slug]
|
159 |
+
return df
|
160 |
+
|
161 |
+
|
162 |
+
def sprint_stats(stats):
|
163 |
+
msg = ''
|
164 |
+
msg += 'artist\t accuracy\t is_correct\t total\n'
|
165 |
+
for key in sorted(stats.keys()):
|
166 |
+
msg += key + '\t {:.5f}\t {}\t \t{}\n'.format(*stats[key])
|
167 |
+
return msg
|
168 |
+
|
169 |
+
|
170 |
+
if __name__ == '__main__':
|
171 |
+
import sys
|
172 |
+
|
173 |
+
args = parse_args(sys.argv[1:])
|
174 |
+
|
175 |
+
if not os.path.exists(os.path.dirname(args['log_path'])):
|
176 |
+
os.makedirs(os.path.dirname(args['log_path']))
|
177 |
+
logger = Logger(args['log_path'])
|
178 |
+
print 'Snapshot: {}'.format(args['snapshot_path'])
|
179 |
+
extractor = create_slim_extractor(args)
|
180 |
+
classification_layer = classification_layer[args['net']]
|
181 |
+
|
182 |
+
stats = dict()
|
183 |
+
assert artist_2_label_wikiart is not None
|
184 |
+
for artist in artist_2_label_wikiart.keys():
|
185 |
+
print('Method:', args['method'])
|
186 |
+
logger.log('Artist: {}'.format(artist))
|
187 |
+
images_df = get_images_df(dataset=args['dataset'], method=args['method'], artist_slug=artist)
|
188 |
+
acc, num_is_correct, num_total = run(extractor, classification_layer, images_df,
|
189 |
+
batch_size=args['batch_size'], logger=logger)
|
190 |
+
stats[artist] = (acc, num_is_correct, num_total)
|
191 |
+
|
192 |
+
logger.log('{}'.format(pformat(args)))
|
193 |
+
print 'Images dir:', results_dir[args['method']]
|
194 |
+
logger.log('===\n\n')
|
195 |
+
logger.log(args['method'])
|
196 |
+
logger.log('{}'.format(sprint_stats(stats)))
|
evaluation/logger.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2018 Artsiom Sanakoyeu and Dmytro Kotovenko
|
2 |
+
#
|
3 |
+
# This file is part of Adaptive Style Transfer
|
4 |
+
#
|
5 |
+
# Adaptive Style Transfer is free software: you can redistribute it and/or modify
|
6 |
+
# it under the terms of the GNU General Public License as published by
|
7 |
+
# the Free Software Foundation, either version 3 of the License, or
|
8 |
+
# (at your option) any later version.
|
9 |
+
#
|
10 |
+
# Adaptive Style Transfer is distributed in the hope that it will be useful,
|
11 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13 |
+
# GNU General Public License for more details.
|
14 |
+
#
|
15 |
+
# You should have received a copy of the GNU General Public License
|
16 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
17 |
+
|
18 |
+
import sys
|
19 |
+
import os
|
20 |
+
import subprocess
|
21 |
+
|
22 |
+
|
23 |
+
class Logger(object):
|
24 |
+
def __init__(self, filepath=None, mode='w'):
|
25 |
+
self.file = None
|
26 |
+
self.filepath = filepath
|
27 |
+
if filepath is not None:
|
28 |
+
self.file = open(filepath, mode=mode, buffering=0)
|
29 |
+
|
30 |
+
def __enter__(self):
|
31 |
+
return self
|
32 |
+
|
33 |
+
def log(self, msg, should_print=True):
|
34 |
+
if should_print:
|
35 |
+
print '[LOG] {}'.format(msg)
|
36 |
+
if self.file is not None:
|
37 |
+
self.file.write('{}\n'.format(msg))
|
38 |
+
|
39 |
+
def write(self, msg):
|
40 |
+
sys.__stdout__.write(msg)
|
41 |
+
if self.file is not None:
|
42 |
+
self.file.write(msg)
|
43 |
+
self.file.flush()
|
44 |
+
|
45 |
+
def close(self):
|
46 |
+
if self.file is not None:
|
47 |
+
self.file.close()
|
48 |
+
|
49 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
50 |
+
self.close()
|
51 |
+
|
52 |
+
|
53 |
+
def log(logger, msg, should_print=True):
|
54 |
+
if logger:
|
55 |
+
logger.log(msg, should_print)
|
56 |
+
else:
|
57 |
+
if should_print:
|
58 |
+
print msg
|
59 |
+
|
60 |
+
|
61 |
+
class Tee:
|
62 |
+
def __init__(self, log_path):
|
63 |
+
self.prev_stdout_descriptor = os.dup(sys.stdout.fileno())
|
64 |
+
self.prev_stderr_descriptor = os.dup(sys.stderr.fileno())
|
65 |
+
|
66 |
+
tee = subprocess.Popen(['tee', log_path], stdin=subprocess.PIPE)
|
67 |
+
os.dup2(tee.stdin.fileno(), sys.stdout.fileno())
|
68 |
+
os.dup2(tee.stdin.fileno(), sys.stderr.fileno())
|
69 |
+
|
70 |
+
def close(self):
|
71 |
+
os.dup2(self.prev_stdout_descriptor, sys.stdout.fileno())
|
72 |
+
os.close(self.prev_stdout_descriptor)
|
73 |
+
os.dup2(self.prev_stderr_descriptor, sys.stderr.fileno())
|
74 |
+
os.close(self.prev_stderr_descriptor)
|
75 |
+
|
76 |
+
def __enter__(self):
|
77 |
+
return self
|
78 |
+
|
79 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
80 |
+
self.close()
|
evaluation/run_deception_score_vgg_16_wikiart.sh
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# Copyright (C) 2018 Artsiom Sanakoyeu and Dmytro Kotovenko
|
4 |
+
#
|
5 |
+
# This file is part of Adaptive Style Transfer
|
6 |
+
#
|
7 |
+
# Adaptive Style Transfer is free software: you can redistribute it and/or modify
|
8 |
+
# it under the terms of the GNU General Public License as published by
|
9 |
+
# the Free Software Foundation, either version 3 of the License, or
|
10 |
+
# (at your option) any later version.
|
11 |
+
#
|
12 |
+
# Adaptive Style Transfer is distributed in the hope that it will be useful,
|
13 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
14 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
15 |
+
# GNU General Public License for more details.
|
16 |
+
#
|
17 |
+
# You should have received a copy of the GNU General Public License
|
18 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
19 |
+
|
20 |
+
set -e
|
21 |
+
LOG_DIR=logs
|
22 |
+
mkdir -p ${LOG_DIR}
|
23 |
+
|
24 |
+
NET=vgg_16_multihead
|
25 |
+
|
26 |
+
METHODS=( "ours" )
|
27 |
+
#METHODS=( "ours" "real_wiki_test" )
|
28 |
+
|
29 |
+
for method in ${METHODS[@]}
|
30 |
+
do
|
31 |
+
echo $method
|
32 |
+
python eval_deception_score.py \
|
33 |
+
-net=${NET} \
|
34 |
+
-s="evaluation_data/model.ckpt-790000" \
|
35 |
+
-log=${LOG_DIR}/deception_score_${method}.txt \
|
36 |
+
--method=$method \
|
37 |
+
--num_classes=624 \
|
38 |
+
--dataset="wikiart"
|
39 |
+
done
|