AK391 committed on
Commit
7788a23
·
1 Parent(s): 5f7f727

example files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +42 -0
  2. bin/analyze_errors.py +316 -0
  3. bin/blur_predicts.py +57 -0
  4. bin/calc_dataset_stats.py +88 -0
  5. bin/debug/analyze_overlapping_masks.sh +31 -0
  6. bin/evaluate_predicts.py +79 -0
  7. bin/evaluator_example.py +76 -0
  8. bin/extract_masks.py +63 -0
  9. bin/filter_sharded_dataset.py +69 -0
  10. bin/gen_debug_mask_dataset.py +61 -0
  11. bin/gen_mask_dataset.py +130 -0
  12. bin/gen_mask_dataset_hydra.py +124 -0
  13. bin/gen_outpainting_dataset.py +88 -0
  14. bin/make_checkpoint.py +79 -0
  15. bin/mask_example.py +14 -0
  16. bin/paper_runfiles/blur_tests.sh +37 -0
  17. bin/paper_runfiles/env.sh +8 -0
  18. bin/paper_runfiles/find_best_checkpoint.py +54 -0
  19. bin/paper_runfiles/generate_test_celeba-hq.sh +17 -0
  20. bin/paper_runfiles/generate_test_ffhq.sh +17 -0
  21. bin/paper_runfiles/generate_test_paris.sh +17 -0
  22. bin/paper_runfiles/generate_test_paris_256.sh +17 -0
  23. bin/paper_runfiles/generate_val_test.sh +28 -0
  24. bin/paper_runfiles/predict_inner_features.sh +20 -0
  25. bin/paper_runfiles/update_test_data_stats.sh +30 -0
  26. bin/predict.py +89 -0
  27. bin/predict_inner_features.py +119 -0
  28. bin/report_from_tb.py +83 -0
  29. bin/sample_from_dataset.py +87 -0
  30. bin/side_by_side.py +76 -0
  31. bin/split_tar.py +22 -0
  32. bin/train.py +72 -0
  33. canvas.png +0 -0
  34. conda_env.yml +165 -0
  35. configs/analyze_mask_errors.yaml +7 -0
  36. configs/data_gen/gen_segm_dataset1.yaml +25 -0
  37. configs/data_gen/gen_segm_dataset3.yaml +25 -0
  38. configs/data_gen/random_medium_256.yaml +33 -0
  39. configs/data_gen/random_medium_512.yaml +33 -0
  40. configs/data_gen/random_thick_256.yaml +33 -0
  41. configs/data_gen/random_thick_512.yaml +33 -0
  42. configs/data_gen/random_thin_256.yaml +25 -0
  43. configs/data_gen/random_thin_512.yaml +25 -0
  44. configs/data_gen/segm_256.yaml +27 -0
  45. configs/data_gen/segm_512.yaml +27 -0
  46. configs/data_gen/sr_256.yaml +25 -0
  47. configs/data_gen/whydra/location/mml-ws01-celeba-hq.yaml +5 -0
  48. configs/data_gen/whydra/location/mml-ws01-ffhq.yaml +5 -0
  49. configs/data_gen/whydra/location/mml-ws01-paris.yaml +5 -0
  50. configs/data_gen/whydra/location/mml7-places.yaml +5 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Download and unpack the pretrained big-lama checkpoint before the heavy
# imports below; predict.py is later pointed at the unpacked directory.
os.system("gdown https://drive.google.com/uc?id=1-95IOJ-2y9BtmABiffIwndPqNZD_gLnV")
os.system("unzip big-lama.zip")
import cv2
import paddlehub as hub
import gradio as gr
import torch
from PIL import Image, ImageOps
import numpy as np

# exist_ok=True: a plain os.mkdir raises FileExistsError when the app is
# restarted in a working directory that already holds these folders.
os.makedirs("data", exist_ok=True)
os.makedirs("dataout", exist_ok=True)
model = hub.Module(name='U2Net')


def infer(img, mask, option):
    """Inpaint ``img`` using either an automatic or a user-drawn mask.

    Args:
        img: PIL image to inpaint; it is shrunk to fit within 700x700.
        mask: PIL image drawn by the user; only used when ``option`` is not
            "automatic (U2net)", in which case it is resized to the image size.
        option: "automatic (U2net)" to segment the mask with the U2Net module,
            anything else to use ``mask``.

    Returns:
        A tuple ``(path_to_result, mask_image)``.
    """
    img = ImageOps.contain(img, (700,700))
    width, height = img.size
    img.save("./data/data.png")
    if option == "automatic (U2net)":
        result = model.Segmentation(
            images=[cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)],
            paths=None,
            batch_size=1,
            input_size=320,
            output_dir='output',
            visualization=True)
        im = Image.fromarray(result[0]['mask'])
    else:
        mask = mask.resize((width,height))
        im = mask
    im.save("./data/data_mask.png")
    # Inference runs as a separate process over the ./data folder.
    os.system('python predict.py model.path=/home/user/app/big-lama/ indir=/home/user/app/data/ outdir=/home/user/app/dataout/ device=cpu')
    # NOTE(review): presumably predict.py names its output after the mask
    # file - confirm against predict.py.
    return "./dataout/data_mask.png",im


inputs = [gr.inputs.Image(type='pil', label="Original Image"),gr.inputs.Image(type='pil',source="canvas", label="Mask",invert_colors=True),gr.inputs.Radio(choices=["automatic (U2net)","manual"], type="value", default="manual", label="Masking option")]
outputs = [gr.outputs.Image(type="file",label="output"),gr.outputs.Image(type="pil",label="Mask")]
title = "LaMa Image Inpainting"
description = "Gradio demo for LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Masks are generated by U^2net"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2109.07161' target='_blank'>Resolution-robust Large Mask Inpainting with Fourier Convolutions</a> | <a href='https://github.com/saic-mdal/lama' target='_blank'>Github Repo</a></p>"
examples = [
    ['person512.png',"canvas.png","automatic (U2net)"],
    ['person512.png',"maskexam.png","manual"]
]
gr.Interface(infer, inputs, outputs, title=title, description=description, article=article, examples=examples).launch(enable_queue=True,cache_examples=True)
bin/analyze_errors.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import cv2
3
+ import numpy as np
4
+ import sklearn
5
+ import torch
6
+ import os
7
+ import pickle
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ from joblib import Parallel, delayed
11
+
12
+ from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset, load_image
13
+ from saicinpainting.evaluation.losses.fid.inception import InceptionV3
14
+ from saicinpainting.evaluation.utils import load_yaml
15
+ from saicinpainting.training.visualizers.base import visualize_mask_and_images
16
+
17
+
18
def draw_score(img, score):
    """Render ``score`` (two decimals) onto ``img`` in place and return ``img``.

    The image is transposed with axes ``(1, 2, 0)`` for drawing and the result
    is transposed back with ``(2, 0, 1)``, so the array layout seen by the
    caller is unchanged.

    Args:
        img: writable numpy array; assumed to be channel-first (C, H, W) -
            TODO confirm against ``load_image``.
        score: number formatted as ``f'{score:.2f}'``.

    Returns:
        ``img`` itself, with the text drawn into it.
    """
    # np.transpose only returns a strided view; OpenCV drawing functions need
    # a contiguous buffer, so draw on a contiguous copy ...
    canvas = np.ascontiguousarray(np.transpose(img, (1, 2, 0)))
    cv2.putText(canvas, f'{score:.2f}',
                (40, 40),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 1, 0),
                thickness=3)
    # ... and copy the pixels back, because callers rely on the in-place
    # effect and ignore the return value.
    img[...] = np.transpose(canvas, (2, 0, 1))
    return img
28
+
29
+
30
def save_global_samples(global_mask_fnames, mask2real_fname, mask2fake_fname, out_dir, real_scores_by_fname, fake_scores_by_fname):
    """Write one side-by-side visualization per mask in ``global_mask_fnames``.

    For each mask the real image and the prediction are loaded, annotated with
    their scores via ``draw_score`` and combined by
    ``visualize_mask_and_images``; the grid is saved into ``out_dir`` as
    ``<mask basename without extension>.jpg``.

    Args:
        global_mask_fnames: iterable of mask file names to visualize.
        mask2real_fname: mapping mask file name -> real image file name.
        mask2fake_fname: mapping mask file name -> predicted image file name.
        out_dir: directory receiving the ``.jpg`` files.
        real_scores_by_fname: table indexed by real file name with a
            ``real_score`` column.
        fake_scores_by_fname: table indexed by mask file name with a
            ``fake_score`` column.
    """
    for mask_path in global_mask_fnames:
        real_path = mask2real_fname[mask_path]
        fake_path = mask2fake_fname[mask_path]

        real_img = load_image(real_path, mode='RGB')
        # Crop the prediction to the real image's last two dimensions.
        rows, cols = real_img.shape[1], real_img.shape[2]
        pred_img = load_image(fake_path, mode='RGB')[:, :rows, :cols]
        mask_img = load_image(mask_path, mode='L')[None, ...]

        real_score = real_scores_by_fname.loc[real_path, 'real_score']
        fake_score = fake_scores_by_fname.loc[mask_path, 'fake_score']
        draw_score(real_img, real_score)
        draw_score(pred_img, fake_score)

        grid = visualize_mask_and_images(dict(image=real_img, mask=mask_img, fake=pred_img),
                                         keys=['image', 'fake'],
                                         last_without_mask=True)
        grid = np.clip(grid * 255, 0, 255).astype('uint8')
        grid = cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)

        stem = os.path.splitext(os.path.basename(mask_path))[0]
        cv2.imwrite(os.path.join(out_dir, stem + '.jpg'), grid)
47
+
48
+
49
def save_samples_by_real(worst_best_by_real, mask2fake_fname, fake_info, out_dir):
    """Save, per real image, a worst/best comparison grid and score histograms.

    For every real file name in ``worst_best_by_real.index`` two files are
    written into ``out_dir``: ``<stem>.jpg`` with the real image, the worst
    mask and prediction and the best mask and prediction, and
    ``<stem>_scores.png`` with histograms of the ``fake_score`` and
    ``real_score`` columns of the matching ``fake_info`` rows.

    Args:
        worst_best_by_real: table indexed by real file name with columns
            ``worst``, ``best``, ``real_score``, ``worst_score`` and
            ``best_score``.
        mask2fake_fname: mapping mask file name -> predicted image file name.
        fake_info: table with ``real_fname``, ``fake_score`` and
            ``real_score`` columns.
        out_dir: directory receiving the output files.
    """
    for real_fname in worst_best_by_real.index:
        row = worst_best_by_real.loc
        stem = os.path.splitext(os.path.basename(real_fname))[0]

        worst_mask_path = row[real_fname, 'worst']
        best_mask_path = row[real_fname, 'best']

        orig_img = load_image(real_fname, mode='RGB')
        rows, cols = orig_img.shape[1], orig_img.shape[2]
        worst_mask_img = load_image(worst_mask_path, mode='L')[None, ...]
        worst_fake_img = load_image(mask2fake_fname[worst_mask_path], mode='RGB')[:, :rows, :cols]
        best_mask_img = load_image(best_mask_path, mode='L')[None, ...]
        best_fake_img = load_image(mask2fake_fname[best_mask_path], mode='RGB')[:, :rows, :cols]

        draw_score(orig_img, row[real_fname, 'real_score'])
        draw_score(worst_fake_img, row[real_fname, 'worst_score'])
        draw_score(best_fake_img, row[real_fname, 'best_score'])

        # The 'mask' entry is an all-zero placeholder; the real masks are
        # passed under their own keys.
        panels = dict(image=orig_img, mask=np.zeros_like(worst_mask_img),
                      worst_mask=worst_mask_img, worst_img=worst_fake_img,
                      best_mask=best_mask_img, best_img=best_fake_img)
        grid = visualize_mask_and_images(panels,
                                         keys=['image', 'worst_mask', 'worst_img', 'best_mask', 'best_img'],
                                         rescale_keys=['worst_mask', 'best_mask'],
                                         last_without_mask=True)
        grid = np.clip(grid * 255, 0, 255).astype('uint8')
        grid = cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(out_dir, stem + '.jpg'), grid)

        fig, (fake_ax, real_ax) = plt.subplots(1, 2)
        rows_for_real = fake_info[fake_info['real_fname'] == real_fname]
        rows_for_real['fake_score'].hist(ax=fake_ax)
        rows_for_real['real_score'].hist(ax=real_ax)
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, stem + '_scores.png'))
        plt.close(fig)
83
+
84
+
85
def extract_overlapping_masks(mask_fnames, cur_i, fake_scores_table, max_overlaps_n=2):
    """Pair mask ``cur_i`` with later masks that overlap it.

    Masks are loaded with ``load_image(..., mode='L')`` and binarized at 0.5;
    two masks overlap when any pixel is set in both.

    Args:
        mask_fnames: sequence of mask file names.
        cur_i: index of the reference mask within ``mask_fnames``.
        fake_scores_table: table indexed by mask file name with a
            ``fake_score`` column.
        max_overlaps_n: the scan stops once this many pairs were collected
            (the check happens after appending).

    Returns:
        ``(pairs, score_diffs)`` where ``pairs`` is a list of
        ``(reference_fname, other_fname)`` tuples and ``score_diffs`` holds the
        matching ``other_score - reference_score`` values.
    """
    ref_fname = mask_fnames[cur_i]
    ref_mask = load_image(ref_fname, mode='L')[None, ...] > 0.5
    ref_score = fake_scores_table.loc[ref_fname, 'fake_score']

    pairs = []
    score_diffs = []
    for other_fname in mask_fnames[cur_i + 1:]:
        other_mask = load_image(other_fname, mode='L')[None, ...] > 0.5
        if np.any(ref_mask & other_mask):
            other_score = fake_scores_table.loc[other_fname, 'fake_score']
            pairs.append((ref_fname, other_fname))
            score_diffs.append(other_score - ref_score)
            if len(pairs) >= max_overlaps_n:
                break
    return pairs, score_diffs
101
+
102
+
103
def _ensure_dir(*parts):
    """Join ``parts`` into a path, create that directory if needed, return the path."""
    path = os.path.join(*parts)
    os.makedirs(path, exist_ok=True)
    return path


def _dump_pickle(dirname, fname, obj):
    """Pickle ``obj`` to ``dirname/fname`` with protocol 3."""
    with open(os.path.join(dirname, fname), 'wb') as f:
        pickle.dump(obj, f, protocol=3)


def _load_pickle(dirname, fname):
    """Unpickle and return the object stored in ``dirname/fname``."""
    with open(os.path.join(dirname, fname), 'rb') as f:
        return pickle.load(f)


def main(args):
    """Score real vs. inpainted images with a linear SVM and dump a report.

    Unless ``args.only_report`` is set, InceptionV3 features are extracted for
    every real image and every prediction, a LinearSVC is fitted to separate
    them, and features, labels, file-name tables and decision scores are
    pickled into ``<outpath>/latents``. With ``args.only_report`` those
    pickles are loaded instead.

    The report consists of a joined score table (tab-separated CSV), a global
    score histogram, the globally worst/best samples and, per real image,
    worst/best comparisons ranked by several score differences.

    Args:
        args: namespace with ``config``, ``datadir``, ``predictdir``,
            ``outpath``, ``only_report`` and ``n_jobs``.

    NOTE: ``--only-report`` unpickles files from ``<outpath>/latents``; only
    point it at output directories you trust.
    """
    config = load_yaml(args.config)

    latents_dir = _ensure_dir(args.outpath, 'latents')
    global_worst_dir = _ensure_dir(args.outpath, 'global_worst')
    global_best_dir = _ensure_dir(args.outpath, 'global_best')
    by_real_root = os.path.join(args.outpath, 'worst_best_by_real')
    worst_best_by_best_worst_score_diff_max_dir = _ensure_dir(by_real_root, 'best_worst_score_diff_max')
    worst_best_by_best_worst_score_diff_min_dir = _ensure_dir(by_real_root, 'best_worst_score_diff_min')
    worst_best_by_real_best_score_diff_max_dir = _ensure_dir(by_real_root, 'real_best_score_diff_max')
    worst_best_by_real_best_score_diff_min_dir = _ensure_dir(by_real_root, 'real_best_score_diff_min')
    worst_best_by_real_worst_score_diff_max_dir = _ensure_dir(by_real_root, 'real_worst_score_diff_max')
    worst_best_by_real_worst_score_diff_min_dir = _ensure_dir(by_real_root, 'real_worst_score_diff_min')

    if not args.only_report:
        # A bare ``import sklearn`` does not load the ``svm`` submodule, so
        # ``sklearn.svm.LinearSVC`` can raise AttributeError; import explicitly.
        from sklearn.svm import LinearSVC

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        inception_model = InceptionV3([block_idx]).eval().cuda()

        dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)

        # Several masks share one real image; embed each real image only once.
        real2vector_cache = {}

        real_features = []
        fake_features = []

        orig_fnames = []
        mask_fnames = []
        mask2real_fname = {}
        mask2fake_fname = {}

        # Pure inference: no autograd graph is needed for either branch.
        with torch.no_grad():
            for batch_i, batch in enumerate(dataset):
                orig_img_fname = dataset.img_filenames[batch_i]
                mask_fname = dataset.mask_filenames[batch_i]
                fake_fname = dataset.pred_filenames[batch_i]
                mask2real_fname[mask_fname] = orig_img_fname
                mask2fake_fname[mask_fname] = fake_fname

                cur_real_vector = real2vector_cache.get(orig_img_fname, None)
                if cur_real_vector is None:
                    in_img = torch.from_numpy(batch['image'][None, ...]).cuda()
                    cur_real_vector = inception_model(in_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()
                    real2vector_cache[orig_img_fname] = cur_real_vector

                pred_img = torch.from_numpy(batch['inpainted'][None, ...]).cuda()
                cur_fake_vector = inception_model(pred_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()

                real_features.append(cur_real_vector)
                fake_features.append(cur_fake_vector)

                orig_fnames.append(orig_img_fname)
                mask_fnames.append(mask_fname)

        # Real samples first (label 1), then fakes (label 0); the score split
        # below relies on this order.
        ids_features = np.concatenate(real_features + fake_features, axis=0)
        ids_labels = np.array(([1] * len(real_features)) + ([0] * len(fake_features)))

        # 'featues.pkl' (sic): file name kept as-is so existing outputs and
        # external readers keep working.
        _dump_pickle(latents_dir, 'featues.pkl', ids_features)
        _dump_pickle(latents_dir, 'labels.pkl', ids_labels)
        _dump_pickle(latents_dir, 'orig_fnames.pkl', orig_fnames)
        _dump_pickle(latents_dir, 'mask_fnames.pkl', mask_fnames)
        _dump_pickle(latents_dir, 'mask2real_fname.pkl', mask2real_fname)
        _dump_pickle(latents_dir, 'mask2fake_fname.pkl', mask2fake_fname)

        svm = LinearSVC(dual=False)
        svm.fit(ids_features, ids_labels)

        pred_scores = svm.decision_function(ids_features)
        real_scores = pred_scores[:len(real_features)]
        fake_scores = pred_scores[len(real_features):]

        _dump_pickle(latents_dir, 'pred_scores.pkl', pred_scores)
        _dump_pickle(latents_dir, 'real_scores.pkl', real_scores)
        _dump_pickle(latents_dir, 'fake_scores.pkl', fake_scores)
    else:
        orig_fnames = _load_pickle(latents_dir, 'orig_fnames.pkl')
        mask_fnames = _load_pickle(latents_dir, 'mask_fnames.pkl')
        mask2real_fname = _load_pickle(latents_dir, 'mask2real_fname.pkl')
        mask2fake_fname = _load_pickle(latents_dir, 'mask2fake_fname.pkl')
        real_scores = _load_pickle(latents_dir, 'real_scores.pkl')
        fake_scores = _load_pickle(latents_dir, 'fake_scores.pkl')

    real_info = pd.DataFrame(data=[dict(real_fname=fname,
                                        real_score=score)
                                   for fname, score
                                   in zip(orig_fnames, real_scores)])
    real_info.set_index('real_fname', drop=True, inplace=True)

    fake_info = pd.DataFrame(data=[dict(mask_fname=fname,
                                        fake_fname=mask2fake_fname[fname],
                                        real_fname=mask2real_fname[fname],
                                        fake_score=score)
                                   for fname, score
                                   in zip(mask_fnames, fake_scores)])
    fake_info = fake_info.join(real_info, on='real_fname', how='left')
    # real_info has one row per mask, so the join multiplies rows.
    fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)

    fake_stats_by_real = fake_info.groupby('real_fname')['fake_score'].describe()[['mean', 'std']].rename(
        {'mean': 'mean_fake_by_real', 'std': 'std_fake_by_real'}, axis=1)
    fake_info = fake_info.join(fake_stats_by_real, on='real_fname', rsuffix='stat_by_real')
    fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)
    fake_info.to_csv(os.path.join(latents_dir, 'join_scores_table.csv'), sep='\t', index=False)

    fake_scores_table = fake_info.set_index('mask_fname')['fake_score'].to_frame()
    real_scores_table = fake_info.set_index('real_fname')['real_score'].drop_duplicates().to_frame()

    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.hist(fake_scores)
    ax2.hist(real_scores)
    fig.tight_layout()
    fig.savefig(os.path.join(args.outpath, 'global_scores_hist.png'))
    plt.close(fig)

    global_worst_masks = fake_info.sort_values('fake_score', ascending=True)['mask_fname'].iloc[:config.take_global_top].to_list()
    global_best_masks = fake_info.sort_values('fake_score', ascending=False)['mask_fname'].iloc[:config.take_global_top].to_list()
    save_global_samples(global_worst_masks, mask2real_fname, mask2fake_fname, global_worst_dir, real_scores_table, fake_scores_table)
    save_global_samples(global_best_masks, mask2real_fname, mask2fake_fname, global_best_dir, real_scores_table, fake_scores_table)

    # grouped by real
    worst_samples_by_real = fake_info.groupby('real_fname').apply(
        lambda d: d.set_index('mask_fname')['fake_score'].idxmin()).to_frame().rename({0: 'worst'}, axis=1)
    best_samples_by_real = fake_info.groupby('real_fname').apply(
        lambda d: d.set_index('mask_fname')['fake_score'].idxmax()).to_frame().rename({0: 'best'}, axis=1)
    worst_best_by_real = pd.concat([worst_samples_by_real, best_samples_by_real], axis=1)

    worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'worst_score'}, axis=1),
                                                 on='worst')
    worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'best_score'}, axis=1),
                                                 on='best')
    worst_best_by_real = worst_best_by_real.join(real_scores_table)

    worst_best_by_real['best_worst_score_diff'] = worst_best_by_real['best_score'] - worst_best_by_real['worst_score']
    worst_best_by_real['real_best_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['best_score']
    worst_best_by_real['real_worst_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['worst_score']

    worst_best_by_best_worst_score_diff_min = worst_best_by_real.sort_values('best_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
    worst_best_by_best_worst_score_diff_max = worst_best_by_real.sort_values('best_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
    save_samples_by_real(worst_best_by_best_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_min_dir)
    save_samples_by_real(worst_best_by_best_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_max_dir)

    worst_best_by_real_best_score_diff_min = worst_best_by_real.sort_values('real_best_score_diff', ascending=True).iloc[:config.take_worst_best_top]
    worst_best_by_real_best_score_diff_max = worst_best_by_real.sort_values('real_best_score_diff', ascending=False).iloc[:config.take_worst_best_top]
    save_samples_by_real(worst_best_by_real_best_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_min_dir)
    save_samples_by_real(worst_best_by_real_best_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_max_dir)

    worst_best_by_real_worst_score_diff_min = worst_best_by_real.sort_values('real_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
    worst_best_by_real_worst_score_diff_max = worst_best_by_real.sort_values('real_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
    save_samples_by_real(worst_best_by_real_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_min_dir)
    save_samples_by_real(worst_best_by_real_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_max_dir)

    # analyze what change of mask causes bigger change of score
    overlapping_mask_fname_pairs = []
    overlapping_mask_fname_score_diffs = []
    # orig_fnames holds one entry per mask, so iterate each real image once
    # (order preserved) instead of re-mining the same image repeatedly.
    for cur_real_fname in dict.fromkeys(orig_fnames):
        cur_fakes_info = fake_info[fake_info['real_fname'] == cur_real_fname]
        cur_mask_fnames = sorted(cur_fakes_info['mask_fname'].unique())

        cur_mask_pairs_and_scores = Parallel(args.n_jobs)(
            delayed(extract_overlapping_masks)(cur_mask_fnames, i, fake_scores_table)
            for i in range(len(cur_mask_fnames) - 1)
        )
        for cur_pairs, cur_scores in cur_mask_pairs_and_scores:
            overlapping_mask_fname_pairs.extend(cur_pairs)
            overlapping_mask_fname_score_diffs.extend(cur_scores)

    # NOTE(review): the sorted pairs/diffs below are not written anywhere yet.
    overlapping_mask_fname_pairs = np.asarray(overlapping_mask_fname_pairs)
    overlapping_mask_fname_score_diffs = np.asarray(overlapping_mask_fname_score_diffs)
    overlapping_sort_idx = np.argsort(overlapping_mask_fname_score_diffs)
    overlapping_mask_fname_pairs = overlapping_mask_fname_pairs[overlapping_sort_idx]
    overlapping_mask_fname_score_diffs = overlapping_mask_fname_score_diffs[overlapping_sort_idx]
295
+
296
+
297
+
298
+
299
+
300
+
301
if __name__ == '__main__':
    import argparse

    # Command-line entry point: four positional paths plus two options.
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str,
                        help='Path to config for dataset generation')
    parser.add_argument('datadir', type=str,
                        help='Path to folder with images and masks (output of gen_mask_dataset.py)')
    parser.add_argument('predictdir', type=str,
                        help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
    parser.add_argument('outpath', type=str,
                        help='Where to put results')
    parser.add_argument('--only-report', action='store_true',
                        help='Whether to skip prediction and feature extraction, '
                             'load all the possible latents and proceed with report only')
    parser.add_argument('--n-jobs', type=int, default=8,
                        help='how many processes to use for pair mask mining')

    cli_args = parser.parse_args()
    main(cli_args)
bin/blur_predicts.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import tqdm
8
+
9
+ from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
10
+ from saicinpainting.evaluation.utils import load_yaml
11
+
12
+
13
def main(args):
    """Blur precomputed predictions and paste them back into the source images.

    Every prediction is smoothed with a Gaussian kernel (``args.k`` x
    ``args.k``, sigma ``args.s``); the result is ``(1 - mask) * image +
    mask * blurred`` and is written under ``args.outpath`` at the prediction's
    path relative to ``args.predictdir``.

    Args:
        args: namespace with ``config``, ``datadir``, ``predictdir``,
            ``outpath``, ``s`` and ``k``.
    """
    config = load_yaml(args.config)

    # The trailing slash makes the prefix strip below leave a relative path.
    if not args.predictdir.endswith('/'):
        args.predictdir += '/'

    dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)

    # Create the output folder itself: os.path.dirname(outpath) is '' for a
    # bare relative name such as "out", and os.makedirs('') raises.
    os.makedirs(args.outpath, exist_ok=True)

    for img_i in tqdm.trange(len(dataset)):
        pred_fname = dataset.pred_filenames[img_i]
        cur_out_fname = os.path.join(args.outpath, pred_fname[len(args.predictdir):])
        os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)

        sample = dataset[img_i]
        img = sample['image']
        mask = sample['mask']
        inpainted = sample['inpainted']

        # Arrays are transposed with (1, 2, 0) before being handed to OpenCV.
        inpainted_blurred = cv2.GaussianBlur(np.transpose(inpainted, (1, 2, 0)),
                                             ksize=(args.k, args.k),
                                             sigmaX=args.s, sigmaY=args.s,
                                             borderType=cv2.BORDER_REFLECT)

        # NOTE(review): assumes sample['mask'] broadcasts against the
        # transposed image - confirm against the dataset class.
        cur_res = (1 - mask) * np.transpose(img, (1, 2, 0)) + mask * inpainted_blurred
        cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
        cv2.imwrite(cur_out_fname, cur_res)
42
+
43
+
44
if __name__ == '__main__':
    import argparse

    # Command-line entry point: four positional paths plus blur parameters.
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str,
                        help='Path to evaluation config')
    parser.add_argument('datadir', type=str,
                        help='Path to folder with images and masks (output of gen_mask_dataset.py)')
    parser.add_argument('predictdir', type=str,
                        help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
    parser.add_argument('outpath', type=str,
                        help='Where to put results')
    parser.add_argument('-s', type=float, default=0.1,
                        help='Gaussian blur sigma')
    parser.add_argument('-k', type=int, default=5,
                        help='Kernel size in gaussian blur')

    cli_args = parser.parse_args()
    main(cli_args)
bin/calc_dataset_stats.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+
5
+ import numpy as np
6
+ import tqdm
7
+ from scipy.ndimage.morphology import distance_transform_edt
8
+
9
+ from saicinpainting.evaluation.data import InpaintingDataset
10
+ from saicinpainting.evaluation.vis import save_item_for_vis
11
+
12
+
13
def main(args):
    """Collect size and hole statistics for a mask dataset and write a summary.

    Writes ``<outdir>/summary.txt`` with min/max/mean of image height, width,
    area, hole area, hole area percentage and the mean distance-transform
    value inside the hole, followed by per-bin sample counts. It also saves up
    to ``args.samples_n`` randomly chosen samples per hole-area bin into
    ``<outdir>/samples/<bin title>/``.

    Args:
        args: namespace with ``datadir``, ``outdir``, ``samples_n`` and
            ``area_bins``.
    """
    dataset = InpaintingDataset(args.datadir, img_suffix='.png')

    area_bins = np.linspace(0, 1, args.area_bins + 1)

    heights = []
    widths = []
    image_areas = []
    hole_areas = []
    hole_area_percents = []
    known_pixel_distances = []

    # Integer counters, so the summary prints "3" rather than "3.0".
    area_bins_count = np.zeros(args.area_bins, dtype=int)
    area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)]

    # For every bin, the dataset indices that fell into it.
    bin2i = [[] for _ in range(args.area_bins)]

    for i, item in enumerate(tqdm.tqdm(dataset)):
        h, w = item['image'].shape[1:]
        heights.append(h)
        widths.append(w)
        full_area = h * w
        image_areas.append(full_area)
        bin_mask = item['mask'] > 0.5
        hole_area = bin_mask.sum()
        hole_areas.append(hole_area)
        hole_percent = hole_area / full_area
        hole_area_percents.append(hole_percent)
        # searchsorted gives the right edge; clip keeps 0% and 100% in range.
        bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, len(area_bins_count) - 1)
        area_bins_count[bin_i] += 1
        bin2i[bin_i].append(i)

        # Distance of every hole pixel to the nearest non-hole pixel.
        cur_dist = distance_transform_edt(bin_mask)
        cur_dist_inside_mask = cur_dist[bin_mask]
        known_pixel_distances.append(cur_dist_inside_mask.mean())

    os.makedirs(args.outdir, exist_ok=True)
    with open(os.path.join(args.outdir, 'summary.txt'), 'w') as f:
        f.write(f'''Location: {args.datadir}

Number of samples: {len(dataset)}

Image height: min {min(heights):5d} max {max(heights):5d} mean {np.mean(heights):.2f}
Image width: min {min(widths):5d} max {max(widths):5d} mean {np.mean(widths):.2f}
Image area: min {min(image_areas):7d} max {max(image_areas):7d} mean {np.mean(image_areas):.2f}
Hole area: min {min(hole_areas):7d} max {max(hole_areas):7d} mean {np.mean(hole_areas):.2f}
Hole area %: min {min(hole_area_percents) * 100:2.2f} max {max(hole_area_percents) * 100:2.2f} mean {np.mean(hole_area_percents) * 100:2.2f}
Dist 2known: min {min(known_pixel_distances):2.2f} max {max(known_pixel_distances):2.2f} mean {np.mean(known_pixel_distances):2.2f} median {np.median(known_pixel_distances):2.2f}

Stats by hole area %:
''')
        for bin_i in range(args.area_bins):
            f.write(f'{area_bin_titles[bin_i]}%: '
                    f'samples number {area_bins_count[bin_i]}, '
                    f'{area_bins_count[bin_i] / len(dataset) * 100:.1f}%\n')

    for bin_i in range(args.area_bins):
        bindir = os.path.join(args.outdir, 'samples', area_bin_titles[bin_i])
        os.makedirs(bindir, exist_ok=True)
        bin_idx = bin2i[bin_i]
        for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False):
            save_item_for_vis(dataset[sample_i], os.path.join(bindir, f'{sample_i}.png'))
75
+
76
+
77
if __name__ == '__main__':
    import argparse

    # Command-line entry point: input folder, output folder, two options.
    parser = argparse.ArgumentParser()
    parser.add_argument('datadir', type=str,
                        help='Path to folder with images and masks (output of gen_mask_dataset.py)')
    parser.add_argument('outdir', type=str,
                        help='Where to put results')
    parser.add_argument('--samples-n', type=int, default=10,
                        help='Number of sample images with masks to copy for visualization for each area bin')
    parser.add_argument('--area-bins', type=int, default=10,
                        help='How many area bins to have')

    cli_args = parser.parse_args()
    main(cli_args)
bin/debug/analyze_overlapping_masks.sh ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Directory of this script; "$0" is quoted so a path containing spaces
# is not split into several arguments for dirname.
BASEDIR="$(dirname "$0")"

# paths are valid for mml7

# select images
#ls /data/inpainting/work/data/train | shuf | head -2000 | xargs -n1 -I{} cp {} /data/inpainting/mask_analysis/src

# generate masks
#"$BASEDIR/../gen_debug_mask_dataset.py" \
#    "$BASEDIR/../../configs/debug_mask_gen.yaml" \
#    "/data/inpainting/mask_analysis/src" \
#    "/data/inpainting/mask_analysis/generated"

# predict
#"$BASEDIR/../predict.py" \
#    model.path="simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/saved_checkpoint/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15_epoch22-step-574999" \
#    indir="/data/inpainting/mask_analysis/generated" \
#    outdir="/data/inpainting/mask_analysis/predicted" \
#    dataset.img_suffix=.jpg \
#    +out_ext=.jpg

# analyze good and bad samples (report only: reuses previously saved latents)
"$BASEDIR/../analyze_errors.py" \
    --only-report \
    --n-jobs 8 \
    "$BASEDIR/../../configs/analyze_mask_errors.yaml" \
    "/data/inpainting/mask_analysis/small/generated" \
    "/data/inpainting/mask_analysis/small/predicted" \
    "/data/inpainting/mask_analysis/small/report"
bin/evaluate_predicts.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+
5
+ import pandas as pd
6
+
7
+ from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
8
+ from saicinpainting.evaluation.evaluator import InpaintingEvaluator, lpips_fid100_f1
9
+ from saicinpainting.evaluation.losses.base_loss import SegmentationAwareSSIM, \
10
+ SegmentationClassStats, SSIMScore, LPIPSScore, FIDScore, SegmentationAwareLPIPS, SegmentationAwareFID
11
+ from saicinpainting.evaluation.utils import load_yaml
12
+
13
+
14
def main(args):
    """Evaluate precomputed inpainting results and write metric tables.

    SSIM, LPIPS and FID are always computed. When the config enables
    ``segmentation``, the segmentation-aware variants and class statistics
    are added and two extra files are written: ``<outpath>_short`` (metrics
    whose name does not start with ``segm_``) and ``<outpath>_segm``
    (per-class results). All tables are tab-separated with 4 decimals.

    Args:
        args: namespace with ``config``, ``datadir``, ``predictdir`` and
            ``outpath`` (path of the main results file).
    """
    config = load_yaml(args.config)

    dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)

    metrics = {
        'ssim': SSIMScore(),
        'lpips': LPIPSScore(),
        'fid': FIDScore()
    }
    enable_segm = config.get('segmentation', dict(enable=False)).get('enable', False)
    if enable_segm:
        weights_path = os.path.expandvars(config.segmentation.weights_path)
        metrics.update(dict(
            segm_stats=SegmentationClassStats(weights_path=weights_path),
            segm_ssim=SegmentationAwareSSIM(weights_path=weights_path),
            segm_lpips=SegmentationAwareLPIPS(weights_path=weights_path),
            segm_fid=SegmentationAwareFID(weights_path=weights_path)
        ))
    evaluator = InpaintingEvaluator(dataset, scores=metrics,
                                    integral_title='lpips_fid100_f1', integral_func=lpips_fid100_f1,
                                    **config.evaluator_kwargs)

    # outpath is a file; a bare file name has an empty dirname and
    # os.makedirs('') raises, so only create the parent when there is one.
    out_dir = os.path.dirname(args.outpath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    results = evaluator.evaluate()

    results = pd.DataFrame(results).stack(1).unstack(0)
    results.dropna(axis=1, how='all', inplace=True)
    results.to_csv(args.outpath, sep='\t', float_format='%.4f')

    if enable_segm:
        only_short_results = results[[c for c in results.columns if not c[0].startswith('segm_')]].dropna(axis=1, how='all')
        only_short_results.to_csv(args.outpath + '_short', sep='\t', float_format='%.4f')

        print(only_short_results)

        segm_metrics_results = results[['segm_ssim', 'segm_lpips', 'segm_fid']].dropna(axis=1, how='all').transpose().unstack(0).reorder_levels([1, 0], axis=1)
        segm_metrics_results.drop(['mean', 'std'], axis=0, inplace=True)

        segm_stats_results = results['segm_stats'].dropna(axis=1, how='all').transpose()
        # Stat names are split on '/' into a two-level index.
        segm_stats_results.index = pd.MultiIndex.from_tuples(n.split('/') for n in segm_stats_results.index)
        segm_stats_results = segm_stats_results.unstack(0).reorder_levels([1, 0], axis=1)
        segm_stats_results.sort_index(axis=1, inplace=True)
        segm_stats_results.dropna(axis=0, how='all', inplace=True)

        segm_results = pd.concat([segm_metrics_results, segm_stats_results], axis=1, sort=True)
        segm_results.sort_values(('mask_freq', 'total'), ascending=False, inplace=True)

        segm_results.to_csv(args.outpath + '_segm', sep='\t', float_format='%.4f')
    else:
        print(results)
66
+
67
+
68
if __name__ == '__main__':
    import argparse

    # Command-line entry point: config, data folder, predictions, output file.
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str,
                        help='Path to evaluation config')
    parser.add_argument('datadir', type=str,
                        help='Path to folder with images and masks (output of gen_mask_dataset.py)')
    parser.add_argument('predictdir', type=str,
                        help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
    parser.add_argument('outpath', type=str,
                        help='Where to put results')

    cli_args = parser.parse_args()
    main(cli_args)
bin/evaluator_example.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from skimage import io
7
+ from skimage.transform import resize
8
+ from torch.utils.data import Dataset
9
+
10
+ from saicinpainting.evaluation.evaluator import InpaintingEvaluator
11
+ from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
12
+
13
+
14
class SimpleImageDataset(Dataset):
    """Dataset over every file of a directory, each resized to a fixed size.

    The directory is listed once at construction; entries are kept in sorted
    name order, so an index always refers to the same file.
    """

    def __init__(self, root_dir, image_size=(400, 600)):
        self.root_dir = root_dir
        self.image_size = image_size
        self.files = sorted(os.listdir(root_dir))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """Read file ``index``, resize it and return it as a float tensor."""
        path = os.path.join(self.root_dir, self.files[index])
        resized = resize(io.imread(path), self.image_size, anti_aliasing=True)
        # permute(2, 0, 1) moves the last axis to the front
        # (presumably HWC -> CHW; assumes a 3-dimensional image array).
        return torch.FloatTensor(resized).permute(2, 0, 1)
29
+
30
+
31
def create_rectangle_mask(height, width):
    """Return a ``(height, width)`` array of ones with a filled rectangle of zeros.

    The rectangle is inset from every border by a quarter of the respective
    dimension (integer division).
    """
    margin_x = width // 4
    margin_y = height // 4
    top_left = (margin_x, margin_y)
    bottom_right = (width - margin_x - 1, height - margin_y - 1)

    canvas = np.ones((height, width))
    # OpenCV points are (x, y); the rectangle is drawn in place on the canvas.
    cv2.rectangle(canvas, top_left, bottom_right, (0, 0, 0), thickness=cv2.FILLED)
    return canvas
37
+
38
+
39
class Model:
    """Baseline "inpainter": fills the hole with the mean colour of the kept pixels."""

    def __call__(self, img_batch, mask_batch):
        """Return ``img_batch`` with mask==0 positions replaced by a per-channel mean.

        ``mask_batch`` is indexed as (batch, H, W) and broadcast over the
        channel axis of ``img_batch``; positions where it is 1 are kept.
        """
        keep = mask_batch[:, None, :, :]
        # per-image, per-channel mean over the kept pixels only
        mean = (img_batch * keep).sum(dim=(2, 3)) / mask_batch.sum(dim=(1, 2))[:, None]
        fill = mean[:, :, None, None] * (1 - keep)
        return fill + img_batch * keep
44
+
45
+
46
class SimpleImageSquareMaskDataset(Dataset):
    """Pair every image of ``dataset`` with one shared rectangle mask and a baseline fill."""

    def __init__(self, dataset):
        self.dataset = dataset
        # The mask is built once, sized from the wrapped dataset's image_size.
        mask_array = create_rectangle_mask(*self.dataset.image_size)
        self.mask = torch.FloatTensor(mask_array)
        self.model = Model()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return a dict with keys ``image``, ``mask`` and ``inpainted``."""
        image = self.dataset[index]
        mask_copy = self.mask.clone()
        # The model works on batches, hence the added leading axis
        # (the returned ``inpainted`` keeps that axis).
        filled = self.model(image[None, ...], mask_copy[None, ...])
        return dict(image=image, mask=mask_copy, inpainted=filled)
60
+
61
+
62
# Example run: score the mean-fill baseline on the images found in ./imgs.
dataset = SimpleImageDataset('imgs')
mask_dataset = SimpleImageSquareMaskDataset(dataset)
model = Model()

metrics = dict(
    ssim=SSIMScore(),
    lpips=LPIPSScore(),
    fid=FIDScore(),
)

evaluator = InpaintingEvaluator(mask_dataset,
                                scores=metrics,
                                batch_size=3,
                                area_grouping=True)

results = evaluator.evaluate(model)
print(results)
bin/extract_masks.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL.Image as Image
2
+ import numpy as np
3
+ import os
4
+
5
+
6
def main(args):
    """Write one mask PNG per file found directly in ``args.indir``.

    Each input is converted to RGB; the mask is white (255) wherever the
    first (red) channel equals 255 and black elsewhere. The mask is saved
    to ``args.outdir`` as ``<input name without extension>_mask000.png``.

    Args:
        args: namespace with ``indir`` (folder with images) and ``outdir``
            (folder for masks, created if missing).
    """
    os.makedirs(args.outdir, exist_ok=True)

    for fname in os.listdir(args.indir):
        img_name = os.path.join(args.indir, fname)
        # os.path.join/splitext instead of string concatenation and fname[:-4]:
        # works without a trailing slash on outdir and with extensions of any length.
        msk_name = os.path.join(args.outdir, os.path.splitext(fname)[0] + '_mask000.png')

        # Close the source file as soon as the pixels are copied out.
        with Image.open(img_name) as src:
            image = np.transpose(np.array(src.convert('RGB')), (2, 0, 1))

        mask = (image == 255).astype(int)

        print(mask.dtype, mask.shape)

        Image.fromarray(
            np.clip(mask[0, :, :] * 255, 0, 255).astype('uint8'), mode='L'
        ).save(msk_name)
54
+
55
+
56
+
57
if __name__ == '__main__':
    import argparse

    aparser = argparse.ArgumentParser()
    # Both options are mandatory: main() cannot run without them, so let argparse
    # report a missing option instead of failing later on a None value.
    aparser.add_argument('--indir', type=str, required=True, help='Path to folder with images')
    aparser.add_argument('--outdir', type=str, required=True,
                         help='Path to folder to store aligned images and masks to')

    main(aparser.parse_args())
bin/filter_sharded_dataset.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+
4
+ import math
5
+ import os
6
+ import random
7
+
8
+ import braceexpand
9
+ import webdataset as wds
10
+
11
+ DEFAULT_CATS_FILE = os.path.join(os.path.dirname(__file__), '..', 'configs', 'places2-categories_157.txt')
12
+
13
def is_good_key(key, cats):
    """Return True if at least one element of ``cats`` occurs inside ``key``."""
    for category in cats:
        if category in key:
            return True
    return False
15
+
16
+
17
def main(args):
    """Re-shard a webdataset, optionally keeping only samples of selected categories.

    Input shards (brace-expanded from ``args.infile``) are split into
    ``args.n_read_streams`` groups, each read through its own shuffled
    iterator. Samples are drawn from a randomly chosen live reader and written
    to a randomly chosen one of ``args.n_write_streams`` shard writers.

    A sample is kept when ``args.categories == 'nofilter'`` or when its
    ``__key__`` contains one of the category names (first space-separated
    token of each non-empty line of the categories file).
    """
    if args.categories == 'nofilter':
        good_categories = None
    else:
        with open(args.categories, 'r') as f:
            good_categories = set(line.strip().split(' ')[0] for line in f if line.strip())

    all_input_files = list(braceexpand.braceexpand(args.infile))
    chunk_size = int(math.ceil(len(all_input_files) / args.n_read_streams))

    input_iterators = [iter(wds.Dataset(all_input_files[start : start + chunk_size]).shuffle(args.shuffle_buffer))
                       for start in range(0, len(all_input_files), chunk_size)]

    output_datasets = []
    try:
        for i in range(args.n_write_streams):
            output_datasets.append(wds.ShardWriter(args.outpattern.format(i)))

        good_readers = list(range(len(input_iterators)))
        step_i = 0
        good_samples = 0
        bad_samples = 0
        while len(good_readers) > 0:
            if step_i % args.print_freq == 0:
                print(f'Iterations done {step_i}; readers alive {good_readers}; good samples {good_samples}; bad samples {bad_samples}')

            step_i += 1

            ri = random.choice(good_readers)
            try:
                sample = next(input_iterators[ri])
            except StopIteration:
                # this reader is exhausted: stop drawing from it
                good_readers.remove(ri)
                continue

            if good_categories is not None and not is_good_key(sample['__key__'], good_categories):
                bad_samples += 1
                continue

            wi = random.randint(0, args.n_write_streams - 1)
            output_datasets[wi].write(sample)
            good_samples += 1
    finally:
        # Writers must be closed explicitly, otherwise the shards still being
        # written are not guaranteed to be finalised on disk.
        for writer in output_datasets:
            writer.close()
55
+
56
+
57
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--categories', type=str, default=DEFAULT_CATS_FILE)
    # integer tuning knobs: (flag, default)
    for flag, default in (('--shuffle-buffer', 10000),
                          ('--n-read-streams', 10),
                          ('--n-write-streams', 10),
                          ('--print-freq', 1000)):
        parser.add_argument(flag, type=int, default=default)
    # positional: input shard pattern, then output pattern
    for positional in ('infile', 'outpattern'):
        parser.add_argument(positional, type=str)

    main(parser.parse_args())
bin/gen_debug_mask_dataset.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import glob
4
+ import os
5
+
6
+ import PIL.Image as Image
7
+ import cv2
8
+ import numpy as np
9
+ import tqdm
10
+ import shutil
11
+
12
+
13
+ from saicinpainting.evaluation.utils import load_yaml
14
+
15
+
16
def generate_masks_for_img(infile, outmask_pattern, mask_size=200, step=0.5):
    """Write a series of sliding-square masks sized to the image ``infile``.

    A ``mask_size`` x ``mask_size`` square of 255 is moved over an otherwise
    zero mask in steps of ``int(mask_size * step)`` pixels, row by row; every
    position is written to ``outmask_pattern.format(i)`` with a running index.

    Args:
        infile: path of the image whose size defines the mask size.
        outmask_pattern: format string with one positional field for the mask index.
        mask_size: side of the square, in pixels.
        step: stride as a fraction of ``mask_size``.

    Raises:
        ValueError: if the stride rounds down to less than one pixel.
    """
    step_abs = int(mask_size * step)
    if step_abs < 1:
        # range() would otherwise fail with an unrelated-looking message (or yield nothing)
        raise ValueError(f'mask_size * step must be at least 1 pixel, got {mask_size} * {step}')

    # Only the size is needed; release the file handle right away.
    with Image.open(infile) as inimg:
        width, height = inimg.size

    mask = np.zeros((height, width), dtype='uint8')
    mask_i = 0

    for start_vertical in range(0, height - step_abs, step_abs):
        for start_horizontal in range(0, width - step_abs, step_abs):
            window = (slice(start_vertical, start_vertical + mask_size),
                      slice(start_horizontal, start_horizontal + mask_size))
            mask[window] = 255

            cv2.imwrite(outmask_pattern.format(mask_i), mask)

            # reset, so the same buffer can be reused for the next position
            mask[window] = 0
            mask_i += 1
32
+
33
+
34
def main(args):
    """Copy every matching image from ``args.indir`` to ``args.outdir`` and generate its debug masks.

    Images are found recursively by the ``img_ext`` suffix from the config;
    the relative directory layout is preserved in the output folder.
    """
    # Both directories must end with '/': relative paths are cut by prefix length below.
    for attr in ('indir', 'outdir'):
        value = getattr(args, attr)
        if not value.endswith('/'):
            setattr(args, attr, value + '/')

    config = load_yaml(args.config)

    search_pattern = os.path.join(args.indir, '**', f'*{config.img_ext}')
    in_files = list(glob.glob(search_pattern, recursive=True))
    for infile in tqdm.tqdm(in_files):
        relative = infile[len(args.indir):]
        outimg = args.outdir + relative
        outmask_pattern = outimg[:-len(config.img_ext)] + '_mask{:04d}.png'

        os.makedirs(os.path.dirname(outimg), exist_ok=True)
        shutil.copy2(infile, outimg)

        generate_masks_for_img(infile, outmask_pattern, **config.gen_kwargs)
51
+
52
+
53
if __name__ == '__main__':
    import argparse

    # Three required positional string arguments, in this order.
    parser = argparse.ArgumentParser()
    for arg_name, arg_help in (
        ('config', 'Path to config for dataset generation'),
        ('indir', 'Path to folder with images'),
        ('outdir', 'Path to folder to store aligned images and masks to'),
    ):
        parser.add_argument(arg_name, type=str, help=arg_help)

    main(parser.parse_args())
bin/gen_mask_dataset.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import glob
4
+ import os
5
+ import shutil
6
+ import traceback
7
+
8
+ import PIL.Image as Image
9
+ import numpy as np
10
+ from joblib import Parallel, delayed
11
+
12
+ from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
13
+ from saicinpainting.evaluation.utils import load_yaml, SmallMode
14
+ from saicinpainting.training.data.masks import MixedMaskGenerator
15
+
16
+
17
class MakeManyMasksWrapper:
    """Adapter giving a single-mask generator a ``get_masks`` method.

    ``impl`` is called ``variants_n`` times per image; the first element of
    each result is collected.
    """

    def __init__(self, impl, variants_n=2):
        self.impl = impl
        self.variants_n = variants_n

    def get_masks(self, img):
        """Return ``variants_n`` masks produced by ``impl`` for ``img``."""
        # impl receives the image with its last axis moved to the front
        channels_first = np.array(img).transpose((2, 0, 1))
        masks = []
        for _ in range(self.variants_n):
            masks.append(self.impl(channels_first)[0])
        return masks
25
+
26
+
27
def process_images(src_images, indir, outdir, config):
    """Generate masks (and matching image crops) for every file in ``src_images``.

    For each image: rescale it so that its shorter side equals
    ``config.cropping.out_min_size`` (images below that size are dropped or
    upscaled according to ``config.cropping.handle_small_mode``), generate
    candidate masks, optionally take a random square crop per mask, filter
    the candidates and save up to ``config.max_masks_per_image`` of them as
    ``<name>_cropNNN.png`` / ``<name>_cropNNN_maskNNN.png`` under ``outdir``.

    Failures on a single image are printed and do not stop the run;
    KeyboardInterrupt ends processing quietly.

    Args:
        src_images: paths of input images, each starting with ``indir``.
        indir: input root; its length is used to cut the relative path, so it
            is expected to end with a path separator.
        outdir: output root.
        config: generation config (``generator_kind``, ``mask_generator_kwargs``,
            ``cropping``, ``max_masks_per_image``, optional ``max_tamper_area``).

    Raises:
        ValueError: on an unknown ``config.generator_kind``.
    """
    if config.generator_kind == 'segmentation':
        mask_generator = SegmentationMask(**config.mask_generator_kwargs)
    elif config.generator_kind == 'random':
        # Work on a copy: popping from the config itself would silently drop
        # 'variants_n' for any later call sharing the same config object.
        generator_kwargs = dict(config.mask_generator_kwargs)
        variants_n = generator_kwargs.pop('variants_n', 2)
        mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**generator_kwargs),
                                              variants_n=variants_n)
    else:
        raise ValueError(f'Unexpected generator kind: {config.generator_kind}')

    max_tamper_area = config.get('max_tamper_area', 1)

    for infile in src_images:
        try:
            file_relpath = infile[len(indir):]
            img_outpath = os.path.join(outdir, file_relpath)
            os.makedirs(os.path.dirname(img_outpath), exist_ok=True)

            image = Image.open(infile).convert('RGB')

            # scale input image to output resolution and filter smaller images
            needs_resize = True
            if min(image.size) < config.cropping.out_min_size:
                handle_small_mode = SmallMode(config.cropping.handle_small_mode)
                if handle_small_mode == SmallMode.DROP:
                    continue
                # small images are only rescaled in UPSCALE mode
                needs_resize = handle_small_mode == SmallMode.UPSCALE
            if needs_resize:
                factor = config.cropping.out_min_size / min(image.size)
                out_size = (np.array(image.size) * factor).round().astype('uint32')
                image = image.resize(out_size, resample=Image.BICUBIC)

            # generate and select masks
            src_masks = mask_generator.get_masks(image)

            filtered_image_mask_pairs = []
            for cur_mask in src_masks:
                if config.cropping.out_square_crop:
                    (crop_left,
                     crop_top,
                     crop_right,
                     crop_bottom) = propose_random_square_crop(cur_mask,
                                                               min_overlap=config.cropping.crop_min_overlap)
                    cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
                    cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
                else:
                    cur_image = image

                # NOTE(review): np.unique() is empty only for an empty mask array, so the
                # first test does not reject all-zero masks - confirm that this is intended.
                if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
                    continue

                filtered_image_mask_pairs.append((cur_image, cur_mask))

            mask_indices = np.random.choice(len(filtered_image_mask_pairs),
                                            size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
                                            replace=False)

            # crop masks; save masks together with input image
            mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
            for i, idx in enumerate(mask_indices):
                cur_image, cur_mask = filtered_image_mask_pairs[idx]
                cur_basename = mask_basename + f'_crop{i:03d}'
                Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
                                mode='L').save(cur_basename + f'_mask{i:03d}.png')
                cur_image.save(cur_basename + '.png')
        except KeyboardInterrupt:
            return
        except Exception as ex:
            print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
98
+
99
+
100
def main(args):
    """Generate a mask dataset for all ``*.<ext>`` images found under ``args.indir``.

    With ``args.n_jobs == 0`` the images are processed in the current process;
    otherwise they are split into roughly equal chunks that are processed in
    parallel via joblib.
    """
    if not args.indir.endswith('/'):
        args.indir += '/'

    os.makedirs(args.outdir, exist_ok=True)

    config = load_yaml(args.config)

    in_files = list(glob.glob(os.path.join(args.indir, '**', f'*.{args.ext}'), recursive=True))
    if not in_files:
        # Nothing to do; without this guard the parallel branch computes
        # chunk_size == 0 and range() fails with "arg 3 must not be zero".
        print(f'No *.{args.ext} files found in {args.indir}')
        return

    if args.n_jobs == 0:
        process_images(in_files, args.indir, args.outdir, config)
    else:
        in_files_n = len(in_files)
        # ceil(in_files_n / n_jobs)
        chunk_size = in_files_n // args.n_jobs + (1 if in_files_n % args.n_jobs > 0 else 0)
        Parallel(n_jobs=args.n_jobs)(
            delayed(process_images)(in_files[start:start+chunk_size], args.indir, args.outdir, config)
            for start in range(0, len(in_files), chunk_size)
        )
118
+
119
+
120
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # required positionals
    for arg_name, arg_help in (
        ('config', 'Path to config for dataset generation'),
        ('indir', 'Path to folder with images'),
        ('outdir', 'Path to folder to store aligned images and masks to'),
    ):
        parser.add_argument(arg_name, type=str, help=arg_help)
    # optional switches
    parser.add_argument('--n-jobs', type=int, default=0, help='How many processes to use')
    parser.add_argument('--ext', type=str, default='jpg', help='Input image extension')

    main(parser.parse_args())
bin/gen_mask_dataset_hydra.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import glob
4
+ import os
5
+ import shutil
6
+ import traceback
7
+ import hydra
8
+ from omegaconf import OmegaConf
9
+
10
+ import PIL.Image as Image
11
+ import numpy as np
12
+ from joblib import Parallel, delayed
13
+
14
+ from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
15
+ from saicinpainting.evaluation.utils import load_yaml, SmallMode
16
+ from saicinpainting.training.data.masks import MixedMaskGenerator
17
+
18
+
19
+ class MakeManyMasksWrapper:
20
+ def __init__(self, impl, variants_n=2):
21
+ self.impl = impl
22
+ self.variants_n = variants_n
23
+
24
+ def get_masks(self, img):
25
+ img = np.transpose(np.array(img), (2, 0, 1))
26
+ return [self.impl(img)[0] for _ in range(self.variants_n)]
27
+
28
+
29
+ def process_images(src_images, indir, outdir, config):
30
+ if config.generator_kind == 'segmentation':
31
+ mask_generator = SegmentationMask(**config.mask_generator_kwargs)
32
+ elif config.generator_kind == 'random':
33
+ mask_generator_kwargs = OmegaConf.to_container(config.mask_generator_kwargs, resolve=True)
34
+ variants_n = mask_generator_kwargs.pop('variants_n', 2)
35
+ mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**mask_generator_kwargs),
36
+ variants_n=variants_n)
37
+ else:
38
+ raise ValueError(f'Unexpected generator kind: {config.generator_kind}')
39
+
40
+ max_tamper_area = config.get('max_tamper_area', 1)
41
+
42
+ for infile in src_images:
43
+ try:
44
+ file_relpath = infile[len(indir):]
45
+ img_outpath = os.path.join(outdir, file_relpath)
46
+ os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
47
+
48
+ image = Image.open(infile).convert('RGB')
49
+
50
+ # scale input image to output resolution and filter smaller images
51
+ if min(image.size) < config.cropping.out_min_size:
52
+ handle_small_mode = SmallMode(config.cropping.handle_small_mode)
53
+ if handle_small_mode == SmallMode.DROP:
54
+ continue
55
+ elif handle_small_mode == SmallMode.UPSCALE:
56
+ factor = config.cropping.out_min_size / min(image.size)
57
+ out_size = (np.array(image.size) * factor).round().astype('uint32')
58
+ image = image.resize(out_size, resample=Image.BICUBIC)
59
+ else:
60
+ factor = config.cropping.out_min_size / min(image.size)
61
+ out_size = (np.array(image.size) * factor).round().astype('uint32')
62
+ image = image.resize(out_size, resample=Image.BICUBIC)
63
+
64
+ # generate and select masks
65
+ src_masks = mask_generator.get_masks(image)
66
+
67
+ filtered_image_mask_pairs = []
68
+ for cur_mask in src_masks:
69
+ if config.cropping.out_square_crop:
70
+ (crop_left,
71
+ crop_top,
72
+ crop_right,
73
+ crop_bottom) = propose_random_square_crop(cur_mask,
74
+ min_overlap=config.cropping.crop_min_overlap)
75
+ cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
76
+ cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
77
+ else:
78
+ cur_image = image
79
+
80
+ if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
81
+ continue
82
+
83
+ filtered_image_mask_pairs.append((cur_image, cur_mask))
84
+
85
+ mask_indices = np.random.choice(len(filtered_image_mask_pairs),
86
+ size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
87
+ replace=False)
88
+
89
+ # crop masks; save masks together with input image
90
+ mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
91
+ for i, idx in enumerate(mask_indices):
92
+ cur_image, cur_mask = filtered_image_mask_pairs[idx]
93
+ cur_basename = mask_basename + f'_crop{i:03d}'
94
+ Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
95
+ mode='L').save(cur_basename + f'_mask{i:03d}.png')
96
+ cur_image.save(cur_basename + '.png')
97
+ except KeyboardInterrupt:
98
+ return
99
+ except Exception as ex:
100
+ print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
101
+
102
+
103
+ @hydra.main(config_path='../configs/data_gen/whydra', config_name='random_medium_256.yaml')
104
+ def main(config: OmegaConf):
105
+ if not config.indir.endswith('/'):
106
+ config.indir += '/'
107
+
108
+ os.makedirs(config.outdir, exist_ok=True)
109
+
110
+ in_files = list(glob.glob(os.path.join(config.indir, '**', f'*.{config.location.extension}'),
111
+ recursive=True))
112
+ if config.n_jobs == 0:
113
+ process_images(in_files, config.indir, config.outdir, config)
114
+ else:
115
+ in_files_n = len(in_files)
116
+ chunk_size = in_files_n // config.n_jobs + (1 if in_files_n % config.n_jobs > 0 else 0)
117
+ Parallel(n_jobs=config.n_jobs)(
118
+ delayed(process_images)(in_files[start:start+chunk_size], config.indir, config.outdir, config)
119
+ for start in range(0, len(in_files), chunk_size)
120
+ )
121
+
122
+
123
+ if __name__ == '__main__':
124
+ main()
bin/gen_outpainting_dataset.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import glob
3
+ import logging
4
+ import os
5
+ import shutil
6
+ import sys
7
+ import traceback
8
+
9
+ from saicinpainting.evaluation.data import load_image
10
+ from saicinpainting.evaluation.utils import move_to_device
11
+
12
+ os.environ['OMP_NUM_THREADS'] = '1'
13
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
14
+ os.environ['MKL_NUM_THREADS'] = '1'
15
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
16
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
17
+
18
+ import cv2
19
+ import hydra
20
+ import numpy as np
21
+ import torch
22
+ import tqdm
23
+ import yaml
24
+ from omegaconf import OmegaConf
25
+ from torch.utils.data._utils.collate import default_collate
26
+
27
+ from saicinpainting.training.data.datasets import make_default_val_dataset
28
+ from saicinpainting.training.trainers import load_checkpoint
29
+ from saicinpainting.utils import register_debug_signal_handlers
30
+
31
+ LOGGER = logging.getLogger(__name__)
32
+
33
+
34
+ def main(args):
35
+ try:
36
+ if not args.indir.endswith('/'):
37
+ args.indir += '/'
38
+
39
+ for in_img in glob.glob(os.path.join(args.indir, '**', '*' + args.img_suffix), recursive=True):
40
+ if 'mask' in os.path.basename(in_img):
41
+ continue
42
+
43
+ out_img_path = os.path.join(args.outdir, os.path.splitext(in_img[len(args.indir):])[0] + '.png')
44
+ out_mask_path = f'{os.path.splitext(out_img_path)[0]}_mask.png'
45
+
46
+ os.makedirs(os.path.dirname(out_img_path), exist_ok=True)
47
+
48
+ img = load_image(in_img)
49
+ height, width = img.shape[1:]
50
+ pad_h, pad_w = int(height * args.coef / 2), int(width * args.coef / 2)
51
+
52
+ mask = np.zeros((height, width), dtype='uint8')
53
+
54
+ if args.expand:
55
+ img = np.pad(img, ((0, 0), (pad_h, pad_h), (pad_w, pad_w)))
56
+ mask = np.pad(mask, ((pad_h, pad_h), (pad_w, pad_w)), mode='constant', constant_values=255)
57
+ else:
58
+ mask[:pad_h] = 255
59
+ mask[-pad_h:] = 255
60
+ mask[:, :pad_w] = 255
61
+ mask[:, -pad_w:] = 255
62
+
63
+ # img = np.pad(img, ((0, 0), (pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode='symmetric')
64
+ # mask = np.pad(mask, ((pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode = 'symmetric')
65
+
66
+ img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype('uint8')
67
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
68
+ cv2.imwrite(out_img_path, img)
69
+
70
+ cv2.imwrite(out_mask_path, mask)
71
+ except KeyboardInterrupt:
72
+ LOGGER.warning('Interrupted by user')
73
+ except Exception as ex:
74
+ LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
75
+ sys.exit(1)
76
+
77
+
78
if __name__ == '__main__':
    import argparse

    # Each entry: (flag or positional name, keyword arguments for add_argument).
    argument_specs = (
        ('indir', dict(type=str, help='Root directory with images')),
        ('outdir', dict(type=str, help='Where to store results')),
        ('--img-suffix', dict(type=str, default='.png', help='Input image extension')),
        ('--expand', dict(action='store_true',
                          help='Generate mask by padding (true) or by cropping (false)')),
        ('--coef', dict(type=float, default=0.2,
                        help='How much to crop/expand in order to get masks')),
    )

    parser = argparse.ArgumentParser()
    for arg_name, arg_options in argument_specs:
        parser.add_argument(arg_name, **arg_options)

    main(parser.parse_args())
bin/make_checkpoint.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import shutil
5
+
6
+ import torch
7
+
8
+
9
def get_checkpoint_files(s):
    """Translate an epoch specification into checkpoint file name(s).

    A single value (e.g. ``'last'`` or ``'12'``) yields one name with the
    ``.ckpt`` suffix appended; a comma-separated specification yields a list
    with one name per item. Surrounding whitespace is ignored.
    """
    spec = s.strip()
    if ',' not in spec:
        # 'last' needs no special casing: it maps to 'last.ckpt' like any other value
        return spec + '.ckpt'

    names = []
    for part in spec.split(','):
        names.append(get_checkpoint_files(part))
    return names
14
+
15
+
16
def main(args):
    """Build a minimal checkpoint in ``args.outdir`` from training output in ``args.indir``.

    The checkpoint(s) named by ``args.epochs`` are loaded from
    ``<indir>/models``. If several are given, their float32 tensors are
    averaged (other tensors are taken from the first checkpoint). Optimizer
    state is dropped; discriminator and ``loss_*`` weights are dropped unless
    ``args.leave_discriminators`` / ``args.leave_losses`` are set. The result
    is saved as ``<outdir>/models/best.ckpt`` and ``config.yaml`` is copied over.
    """
    checkpoint_fnames = get_checkpoint_files(args.epochs)
    if isinstance(checkpoint_fnames, str):
        checkpoint_fnames = [checkpoint_fnames]
    assert len(checkpoint_fnames) >= 1

    checkpoint_path = os.path.join(args.indir, 'models', checkpoint_fnames[0])
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # pop() instead of del: a checkpoint that was already stripped of its
    # optimizer state is still valid input.
    checkpoint.pop('optimizer_states', None)

    if len(checkpoint_fnames) > 1:
        for fname in checkpoint_fnames[1:]:
            print('sum', fname)
            sum_tensors_cnt = 0
            other_cp = torch.load(os.path.join(args.indir, 'models', fname), map_location='cpu')
            for k in checkpoint['state_dict'].keys():
                if checkpoint['state_dict'][k].dtype is torch.float:
                    checkpoint['state_dict'][k].data.add_(other_cp['state_dict'][k].data)
                    sum_tensors_cnt += 1
            print('summed', sum_tensors_cnt, 'tensors')

        # turn the accumulated sums into means
        for k in checkpoint['state_dict'].keys():
            if checkpoint['state_dict'][k].dtype is torch.float:
                checkpoint['state_dict'][k].data.mul_(1 / float(len(checkpoint_fnames)))

    state_dict = checkpoint['state_dict']

    if not args.leave_discriminators:
        for k in list(state_dict.keys()):
            if k.startswith('discriminator.'):
                del state_dict[k]

    if not args.leave_losses:
        for k in list(state_dict.keys()):
            if k.startswith('loss_'):
                del state_dict[k]

    out_checkpoint_path = os.path.join(args.outdir, 'models', 'best.ckpt')
    os.makedirs(os.path.dirname(out_checkpoint_path), exist_ok=True)

    torch.save(checkpoint, out_checkpoint_path)

    shutil.copy2(os.path.join(args.indir, 'config.yaml'),
                 os.path.join(args.outdir, 'config.yaml'))
60
+
61
+
62
if __name__ == '__main__':
    import argparse

    aparser = argparse.ArgumentParser()
    # Help texts corrected: balanced parenthesis, 'models' (the folder main() reads),
    # and the comma-separated form of --epochs that the code supports.
    aparser.add_argument('indir',
                         help='Path to directory with output of training '
                              '(i.e. directory, which has samples, models, config.yaml and train.log)')
    aparser.add_argument('outdir',
                         help='Where to put minimal checkpoint, which can be consumed by "bin/predict.py"')
    aparser.add_argument('--epochs', type=str, default='last',
                         help='Which checkpoint to take. '
                              'Can be "last", an integer - number of epoch, '
                              'or a comma-separated list of those to average')
    aparser.add_argument('--leave-discriminators', action='store_true',
                         help='If enabled, the state of discriminators will not be removed from the checkpoint')
    aparser.add_argument('--leave-losses', action='store_true',
                         help='If enabled, weights of nn-based losses (e.g. perceptual) will not be removed')

    main(aparser.parse_args())
bin/mask_example.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ from skimage import io
3
+ from skimage.transform import resize
4
+
5
+ from saicinpainting.evaluation.masks.mask import SegmentationMask
6
+
7
# Demo: build segmentation-based masks for one example image,
# then display each mask and save it as a PNG.
im = resize(io.imread('imgs/ex4.jpg'), (512, 1024), anti_aliasing=True)

mask_seg = SegmentationMask(num_variants_per_mask=10)
mask_examples = mask_seg.get_masks(im)

for mask_index, mask_image in enumerate(mask_examples):
    plt.imshow(mask_image)
    plt.show()
    plt.imsave(f'tmp/img_masks/{mask_index}.png', mask_image)
bin/paper_runfiles/blur_tests.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##!/usr/bin/env bash
2
+ #
3
+ ## !!! file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst
4
+ #
5
+ ## paths to data are valid for mml7
6
+ #PLACES_ROOT="/data/inpainting/Places365"
7
+ #OUT_DIR="/data/inpainting/paper_data/Places365_val_test"
8
+ #
9
+ #source "$(dirname $0)/env.sh"
10
+ #
11
+ #for datadir in test_large_30k # val_large
12
+ #do
13
+ # for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
14
+ # do
15
+ # "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
16
+ # "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8
17
+ #
18
+ # "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
19
+ # done
20
+ #
21
+ # for conf in segm_256 segm_512
22
+ # do
23
+ # "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
24
+ # "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2
25
+ #
26
+ # "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
27
+ # done
28
+ #done
29
+ #
30
+ #IN_DIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k/random_medium_512"
31
+ #PRED_DIR="/data/inpainting/predictions/final/images/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37/random_medium_512"
32
+ #BLUR_OUT_DIR="/data/inpainting/predictions/final/blur/images"
33
+ #
34
+ #for b in 0.1
35
+ #
36
+ #"$BINDIR/blur_predicts.py" "$BASEDIR/../../configs/eval2.yaml" "$CUR_IN_DIR" "$CUR_OUT_DIR" "$CUR_EVAL_DIR"
37
+ #
bin/paper_runfiles/env.sh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Common environment for the paper run scripts (meant to be sourced).
# Derives the bin/, source and config directories from the location of the
# script named by $0 and puts the source root on PYTHONPATH.
# All expansions are quoted so that paths containing spaces keep working.
DIRNAME="$(dirname "$0")"
DIRNAME="$(realpath "$DIRNAME")"

BINDIR="$DIRNAME/.."
SRCDIR="$BINDIR/.."
CONFIGDIR="$SRCDIR/configs"

export PYTHONPATH="$SRCDIR:$PYTHONPATH"
bin/paper_runfiles/find_best_checkpoint.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+
4
+ import os
5
+ from argparse import ArgumentParser
6
+
7
+
8
def ssim_fid100_f1(metrics, fid_scale=100):
    """Combine mean SSIM and mean FID of the 'total' row into one F1-like score.

    FID is first mapped to a "higher is better" value
    ``max(0, fid_scale - fid) / fid_scale``; the result is the harmonic-mean
    style combination of that value with SSIM (with 1e-3 added to the
    denominator to avoid division by zero).
    """
    total_ssim = metrics.loc['total', 'ssim']['mean']
    total_fid = metrics.loc['total', 'fid']['mean']

    fid_goodness = max(0, fid_scale - total_fid) / fid_scale
    numerator = 2 * total_ssim * fid_goodness
    return numerator / (total_ssim + fid_goodness + 1e-3)
14
+
15
+
16
def find_best_checkpoint(model_list, models_dir):
    """Find, for every listed model, the validation epoch with the highest f1.

    ``model_list`` is a text file with one model (sub-directory of
    ``models_dir``) per line. For each model ``<models_dir>/<model>/train.log``
    is scanned for lines containing 'Validation metrics after epoch'; epoch
    and step are parsed from that line and the f1 value is the last word of
    the line 5 lines below, which must start with 'total'. One line
    ``model<TAB>epoch<TAB>step<TAB>f1`` per model is written to
    ``<model_list>_best``.
    """
    with open(model_list) as f:
        # skip blank lines: they would otherwise resolve to <models_dir>/train.log
        models = [name for name in (line.strip() for line in f) if name]
    with open(f'{model_list}_best', 'w') as f:
        for model in models:
            print(model)
            best_f1 = 0
            best_epoch = 0
            best_step = 0
            with open(os.path.join(models_dir, model, 'train.log')) as fm:
                lines = fm.readlines()
            for line_index, line in enumerate(lines):
                if 'Validation metrics after epoch' not in line:
                    continue
                # header looks like '... epoch #<epoch>, total <step> ...'
                sharp_index = line.index('#')
                cur_ep = line[sharp_index + 1:]
                comma_index = cur_ep.index(',')
                cur_ep = int(cur_ep[:comma_index])
                total_index = line.index('total ')
                step = int(line[total_index:].split()[1].strip())
                # A log cut off in the middle of a metrics table has no 'total'
                # row for its last header; ignore such an incomplete entry.
                if line_index + 5 >= len(lines):
                    continue
                total_line = lines[line_index + 5]
                if not total_line.startswith('total'):
                    continue
                words = total_line.strip().split()
                f1 = float(words[-1])
                print(f'\tEpoch: {cur_ep}, f1={f1}')
                if f1 > best_f1:
                    best_f1 = f1
                    best_epoch = cur_ep
                    best_step = step
            f.write(f'{model}\t{best_epoch}\t{best_step}\t{best_f1}\n')
47
+
48
+
49
if __name__ == '__main__':
    # Two required positionals: the model list file and the directory holding the models.
    parser = ArgumentParser()
    for positional in ('model_list', 'models_dir'):
        parser.add_argument(positional)
    args = parser.parse_args()
    find_best_checkpoint(args.model_list, args.models_dir)
bin/paper_runfiles/generate_test_celeba-hq.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash

# Generates the random-mask evaluation sets for the CelebA-HQ val and test
# splits and computes statistics for every generated set.
# paths to data are valid for mml-ws01
OUT_DIR="/media/inpainting/paper_data/CelebA-HQ_val_test"

# "$0" and all variables are quoted so paths with spaces do not break the script
source "$(dirname "$0")/env.sh"

for datadir in "val" "test"
do
    for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
    do
        "$BINDIR/gen_mask_dataset_hydra.py" -cn "$conf" datadir="$datadir" location=mml-ws01-celeba-hq \
         location.out_dir="$OUT_DIR" cropping.out_square_crop=False

        "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
    done
done
bin/paper_runfiles/generate_test_ffhq.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash

# Generates the random-mask evaluation sets for the FFHQ test split and
# computes statistics for every generated set.
# paths to data are valid for mml-ws01
OUT_DIR="/media/inpainting/paper_data/FFHQ_val"

# "$0" and all variables are quoted so paths with spaces do not break the script
source "$(dirname "$0")/env.sh"

for datadir in test
do
    for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
    do
        "$BINDIR/gen_mask_dataset_hydra.py" -cn "$conf" datadir="$datadir" location=mml-ws01-ffhq \
         location.out_dir="$OUT_DIR" cropping.out_square_crop=False

        "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
    done
done
bin/paper_runfiles/generate_test_paris.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash

# Generates the Paris StreetView evaluation mask datasets (min size 227): for
# every mask config, run the mask generator and then compute dataset statistics.
# BINDIR is not set in this script; presumably it comes from env.sh.

# paths to data are valid for mml-ws01
OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val"

# "$0" is quoted so the script keeps working when its path contains spaces.
source "$(dirname "$0")/env.sh"

for datadir in paris_eval_gt
do
    for conf in random_thin_256 random_medium_256 random_thick_256 segm_256
    do
        # NOTE: location.out_dir must expand $OUT_DIR. Passing the bare word
        # "OUT_DIR" sends that literal string as the output directory, while
        # the stats step below reads from the real "$OUT_DIR/..." path.
        "$BINDIR/gen_mask_dataset_hydra.py" -cn "$conf" datadir="$datadir" location=mml-ws01-paris \
            location.out_dir="$OUT_DIR" cropping.out_square_crop=False cropping.out_min_size=227

        "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
    done
done
bin/paper_runfiles/generate_test_paris_256.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash

# Generates the Paris StreetView evaluation mask datasets (min size 256): for
# every mask config, run the mask generator and then compute dataset statistics.
# BINDIR is not set in this script; presumably it comes from env.sh.

# paths to data are valid for mml-ws01
OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val_256"

# "$0" is quoted so the script keeps working when its path contains spaces.
source "$(dirname "$0")/env.sh"

for datadir in paris_eval_gt
do
    for conf in random_thin_256 random_medium_256 random_thick_256 segm_256
    do
        "$BINDIR/gen_mask_dataset_hydra.py" -cn "$conf" datadir="$datadir" location=mml-ws01-paris \
            location.out_dir="$OUT_DIR" cropping.out_square_crop=False cropping.out_min_size=256

        "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
    done
done
bin/paper_runfiles/generate_val_test.sh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash

# Generates the Places365 mask datasets: random masks first (8 jobs), then
# segmentation-based masks (2 jobs), computing statistics after each one.
# BINDIR and CONFIGDIR are not set in this script; presumably they come from env.sh.

# !!! file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst

# paths to data are valid for mml7
PLACES_ROOT="/data/inpainting/Places365"
OUT_DIR="/data/inpainting/paper_data/Places365_val_test"

# "$0" is quoted so the script keeps working when its path contains spaces.
source "$(dirname "$0")/env.sh"

for datadir in test_large_30k  # val_large
do
    for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
    do
        "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
            "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8

        "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
    done

    for conf in segm_256 segm_512
    do
        "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
            "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2

        "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
    done
done
bin/paper_runfiles/predict_inner_features.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash

# Runs predict_inner_features.py for two trained models (the "ffc" and "work"
# configs) on the same input directory, writing to separate output directories.
# BINDIR is not set in this script; presumably it comes from env.sh.

# paths to data are valid for mml7

# "$0" is quoted so the script keeps working when its path contains spaces.
source "$(dirname "$0")/env.sh"

"$BINDIR/predict_inner_features.py" \
    -cn default_inner_features_ffc \
    model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-34-05_train_ablv2_work_ffc075_resume_epoch39" \
    indir="/data/inpainting/paper_data/inner_features_vis/input/" \
    outdir="/data/inpainting/paper_data/inner_features_vis/output/ffc" \
    dataset.img_suffix=.png


"$BINDIR/predict_inner_features.py" \
    -cn default_inner_features_work \
    model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37" \
    indir="/data/inpainting/paper_data/inner_features_vis/input/" \
    outdir="/data/inpainting/paper_data/inner_features_vis/output/work" \
    dataset.img_suffix=.png
bin/paper_runfiles/update_test_data_stats.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash

# Recomputes dataset statistics (written to "<dataset>_stats2") for the
# CelebA-HQ test and Paris StreetView 256 evaluation mask datasets.
# BINDIR is not set in this script; presumably it comes from env.sh.

# paths to data are valid for mml7

# "$0" is quoted so the script keeps working when its path contains spaces.
source "$(dirname "$0")/env.sh"

#INDIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k"
#
#for dataset in random_medium_256 random_medium_512 random_thick_256 random_thick_512 random_thin_256 random_thin_512
#do
#    "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
#done
#
#"$BINDIR/calc_dataset_stats.py" "/data/inpainting/evalset2" "/data/inpainting/evalset2_stats2"


INDIR="/data/inpainting/paper_data/CelebA-HQ_val_test/test"

for dataset in random_medium_256 random_thick_256 random_thin_256
do
    "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
done


INDIR="/data/inpainting/paper_data/Paris_StreetView_Dataset_val_256/paris_eval_gt"

for dataset in random_medium_256 random_thick_256 random_thin_256
do
    "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
done
bin/predict.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Example command:
4
+ # ./bin/predict.py \
5
+ # model.path=<path to checkpoint, prepared by make_checkpoint.py> \
6
+ # indir=<path to input data> \
7
+ # outdir=<where to store predicts>
8
+
9
+ import logging
10
+ import os
11
+ import sys
12
+ import traceback
13
+
14
+ from saicinpainting.evaluation.utils import move_to_device
15
+
16
+ os.environ['OMP_NUM_THREADS'] = '1'
17
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
18
+ os.environ['MKL_NUM_THREADS'] = '1'
19
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
20
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
21
+
22
+ import cv2
23
+ import hydra
24
+ import numpy as np
25
+ import torch
26
+ import tqdm
27
+ import yaml
28
+ from omegaconf import OmegaConf
29
+ from torch.utils.data._utils.collate import default_collate
30
+
31
+ from saicinpainting.training.data.datasets import make_default_val_dataset
32
+ from saicinpainting.training.trainers import load_checkpoint
33
+ from saicinpainting.utils import register_debug_signal_handlers
34
+
35
+ LOGGER = logging.getLogger(__name__)
36
+
37
+
38
@hydra.main(config_path='../configs/prediction', config_name='default.yaml')
def main(predict_config: OmegaConf):
    """Run the inpainting model over every item of a dataset and save the results.

    Reads the training config and checkpoint from ``predict_config.model.path``,
    builds the validation dataset from ``predict_config.indir`` and, for each
    mask file, writes the ``predict_config.out_key`` entry of the model output
    to the mirrored relative path under ``predict_config.outdir`` (extension
    ``out_ext``, default ``.png``).

    Exits with status 1 on any error other than a keyboard interrupt.
    """
    try:
        register_debug_signal_handlers()  # kill -10 <pid> will result in traceback dumped into log

        device = torch.device(predict_config.device)

        model_dir = predict_config.model.path
        with open(os.path.join(model_dir, 'config.yaml'), 'r') as cfg_file:
            train_config = OmegaConf.create(yaml.safe_load(cfg_file))

        train_config.training_model.predict_only = True

        out_ext = predict_config.get('out_ext', '.png')

        weights_path = os.path.join(model_dir, 'models', predict_config.model.checkpoint)
        model = load_checkpoint(train_config, weights_path, strict=False, map_location='cpu')
        model.freeze()
        model.to(device)

        # The trailing slash makes the prefix strip below yield a relative path.
        if not predict_config.indir.endswith('/'):
            predict_config.indir += '/'

        dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
        prefix_len = len(predict_config.indir)
        with torch.no_grad():
            for img_i in tqdm.trange(len(dataset)):
                # Mirror the mask's path (relative to indir) under outdir.
                rel_stem = os.path.splitext(dataset.mask_filenames[img_i][prefix_len:])[0]
                cur_out_fname = os.path.join(predict_config.outdir, rel_stem + out_ext)
                os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)

                batch = move_to_device(default_collate([dataset[img_i]]), device)
                batch['mask'] = (batch['mask'] > 0) * 1  # binarize the mask
                batch = model(batch)

                # First batch element, moved from (C, H, W) order to (H, W, C).
                predicted = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
                predicted = np.clip(predicted * 255, 0, 255).astype('uint8')
                cv2.imwrite(cur_out_fname, cv2.cvtColor(predicted, cv2.COLOR_RGB2BGR))
    except KeyboardInterrupt:
        LOGGER.warning('Interrupted by user')
    except Exception as ex:
        LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
        sys.exit(1)
86
+
87
+
88
+ if __name__ == '__main__':
89
+ main()
bin/predict_inner_features.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Example command:
4
+ # ./bin/predict_inner_features.py \
5
+ # model.path=<path to checkpoint, prepared by make_checkpoint.py> \
6
+ # indir=<path to input data> \
7
+ # outdir=<where to store predicts>
8
+
9
+ import logging
10
+ import os
11
+ import sys
12
+ import traceback
13
+
14
+ from saicinpainting.evaluation.utils import move_to_device
15
+
16
+ os.environ['OMP_NUM_THREADS'] = '1'
17
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
18
+ os.environ['MKL_NUM_THREADS'] = '1'
19
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
20
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
21
+
22
+ import cv2
23
+ import hydra
24
+ import numpy as np
25
+ import torch
26
+ import tqdm
27
+ import yaml
28
+ from omegaconf import OmegaConf
29
+ from torch.utils.data._utils.collate import default_collate
30
+
31
+ from saicinpainting.training.data.datasets import make_default_val_dataset
32
+ from saicinpainting.training.trainers import load_checkpoint, DefaultInpaintingTrainingModule
33
+ from saicinpainting.utils import register_debug_signal_handlers, get_shape
34
+
35
+ LOGGER = logging.getLogger(__name__)
36
+
37
+
38
@hydra.main(config_path='../configs/prediction', config_name='default_inner_features.yaml')
def main(predict_config: OmegaConf):
    """Save per-level activation-magnitude maps of the generator as PNG files.

    For every dataset item the mask is replaced by a centered square hole of
    half-size ``predict_config.hole_radius``, the masked image is pushed
    through ``model.generator.model`` level by level, and for each level index
    listed in ``predict_config.levels`` the channel-wise RMS of the features
    is normalized and written to ``<outdir>/<relative name>_levNN_norm.png``.

    Exits with status 1 on any error other than a keyboard interrupt.
    """
    try:
        register_debug_signal_handlers()  # kill -10 <pid> will result in traceback dumped into log

        device = torch.device(predict_config.device)

        train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
        with open(train_config_path, 'r') as f:
            train_config = OmegaConf.create(yaml.safe_load(f))

        checkpoint_path = os.path.join(predict_config.model.path, 'models', predict_config.model.checkpoint)
        # Load onto CPU first (as bin/predict.py does) and move to the requested
        # device afterwards, so loading does not depend on the default device.
        model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
        model.freeze()
        model.to(device)

        assert isinstance(model, DefaultInpaintingTrainingModule), 'Only DefaultInpaintingTrainingModule is supported'
        assert isinstance(getattr(model.generator, 'model', None), torch.nn.Sequential)

        # The trailing slash makes the prefix strip below yield a relative path.
        if not predict_config.indir.endswith('/'):
            predict_config.indir += '/'

        dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)

        max_level = max(predict_config.levels)

        with torch.no_grad():
            for img_i in tqdm.trange(len(dataset)):
                mask_fname = dataset.mask_filenames[img_i]
                cur_out_fname = os.path.join(predict_config.outdir,
                                             os.path.splitext(mask_fname[len(predict_config.indir):])[0])
                os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)

                batch = move_to_device(default_collate([dataset[img_i]]), device)

                img = batch['image']
                mask = batch['mask']
                # Discard the dataset mask and punch a centered square hole instead.
                mask[:] = 0
                mask_h, mask_w = mask.shape[-2:]
                mask[:, :,
                     mask_h // 2 - predict_config.hole_radius: mask_h // 2 + predict_config.hole_radius,
                     mask_w // 2 - predict_config.hole_radius: mask_w // 2 + predict_config.hole_radius] = 1

                masked_img = torch.cat([img * (1 - mask), mask], dim=1)

                feats = masked_img
                for level_i, level in enumerate(model.generator.model):
                    feats = level(feats)
                    if level_i in predict_config.levels:
                        # A level may return a tuple; concatenate its tensor members by channel.
                        cur_feats = torch.cat([f for f in feats if torch.is_tensor(f)], dim=1) \
                            if isinstance(feats, tuple) else feats

                        if predict_config.slice_channels:
                            cur_feats = cur_feats[:, slice(*predict_config.slice_channels)]

                        # Channel-wise RMS, shifted to start at zero and scaled by its std.
                        cur_feat = cur_feats.pow(2).mean(1).pow(0.5).clone()
                        cur_feat -= cur_feat.min()
                        # Guard against a constant map: 0 / 0 would give NaN pixels.
                        cur_feat /= cur_feat.std().clamp(min=torch.finfo(cur_feat.dtype).tiny)
                        cur_feat = cur_feat.clamp(0, 1) / 1
                        cur_feat = cur_feat.cpu().numpy()[0]
                        cur_feat *= 255
                        cur_feat = np.clip(cur_feat, 0, 255).astype('uint8')
                        cv2.imwrite(cur_out_fname + f'_lev{level_i:02d}_norm.png', cur_feat)

                    # Nothing past the deepest requested level is needed, so stop
                    # without running any further generator levels.
                    if level_i >= max_level:
                        break
    except KeyboardInterrupt:
        LOGGER.warning('Interrupted by user')
    except Exception as ex:
        LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
        sys.exit(1)
116
+
117
+
118
+ if __name__ == '__main__':
119
+ main()
bin/report_from_tb.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import glob
4
+ import os
5
+ import re
6
+
7
+ import tensorflow as tf
8
+ from torch.utils.tensorboard import SummaryWriter
9
+
10
+
11
+ GROUPING_RULES = [
12
+ re.compile(r'^(?P<group>train|test|val|extra_val_.*?(256|512))_(?P<title>.*)', re.I)
13
+ ]
14
+
15
+
16
+ DROP_RULES = [
17
+ re.compile(r'_std$', re.I)
18
+ ]
19
+
20
+
21
def need_drop(tag):
    """Return True when *tag* matches at least one pattern in DROP_RULES."""
    return any(rule.search(tag) for rule in DROP_RULES)
26
+
27
+
28
def get_group_and_title(tag):
    """Split *tag* into ``(group, title)`` using the first matching GROUPING_RULES entry.

    Returns ``(None, None)`` when no rule matches.
    """
    candidates = (rule.search(tag) for rule in GROUPING_RULES)
    hit = next((m for m in candidates if m is not None), None)
    if hit is None:
        return None, None
    return hit.group('group'), hit.group('title')
35
+
36
+
37
def main(args):
    """Re-write TensorBoard scalars into one run directory per tag group.

    Every event file matched by ``args.inglob`` is read; each scalar whose tag
    is not dropped by ``need_drop`` and can be split by ``get_group_and_title``
    is written under ``<args.outdir>/<experiment name>/<group>`` (the group
    directory is prefixed with the event file's parent directory name when
    ``args.include_version`` is set), with the title as the new tag.
    """
    os.makedirs(args.outdir, exist_ok=True)

    ignored_events = set()  # tags already warned about, so each is reported once

    for orig_fname in glob.glob(args.inglob):
        cur_dirpath = os.path.dirname(orig_fname)  # remove filename, this should point to "version_0" directory
        subdirname = os.path.basename(cur_dirpath)  # == "version_0" most of time
        exp_root_path = os.path.dirname(cur_dirpath)  # remove "version_0"
        exp_name = os.path.basename(exp_root_path)

        writers_by_group = {}

        try:
            for e in tf.compat.v1.train.summary_iterator(orig_fname):
                for v in e.summary.value:
                    if need_drop(v.tag):
                        continue

                    cur_group, cur_title = get_group_and_title(v.tag)
                    if cur_group is None:
                        if v.tag not in ignored_events:
                            print(f'WARNING: Could not detect group for {v.tag}, ignoring it')
                            ignored_events.add(v.tag)
                        continue

                    cur_writer = writers_by_group.get(cur_group, None)
                    if cur_writer is None:
                        if args.include_version:
                            cur_outdir = os.path.join(args.outdir, exp_name, f'{subdirname}_{cur_group}')
                        else:
                            cur_outdir = os.path.join(args.outdir, exp_name, cur_group)
                        cur_writer = SummaryWriter(cur_outdir)
                        writers_by_group[cur_group] = cur_writer

                    cur_writer.add_scalar(cur_title, v.simple_value, global_step=e.step, walltime=e.wall_time)
        finally:
            # Close the writers so buffered events reach disk and file handles
            # are released before the next input file is processed.
            for cur_writer in writers_by_group.values():
                cur_writer.close()
72
+
73
+
74
+ if __name__ == '__main__':
75
+ import argparse
76
+
77
+ aparser = argparse.ArgumentParser()
78
+ aparser.add_argument('inglob', type=str)
79
+ aparser.add_argument('outdir', type=str)
80
+ aparser.add_argument('--include-version', action='store_true',
81
+ help='Include subdirectory name e.g. "version_0" into output path')
82
+
83
+ main(aparser.parse_args())
bin/sample_from_dataset.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+
5
+ import numpy as np
6
+ import tqdm
7
+ from skimage import io
8
+ from skimage.segmentation import mark_boundaries
9
+
10
+ from saicinpainting.evaluation.data import InpaintingDataset
11
+ from saicinpainting.evaluation.vis import save_item_for_vis
12
+
13
def save_mask_for_sidebyside(item, out_file):
    """Write ``item['mask']`` to *out_file* as an 8-bit image.

    A 3-dimensional mask is reduced to its first plane; values are scaled by
    255 and clipped to [0, 255] before saving.
    """
    raw = item['mask']
    plane = raw[0] if raw.ndim == 3 else raw
    io.imsave(out_file, np.clip(plane * 255, 0, 255).astype('uint8'))
19
+
20
def save_img_for_sidebyside(item, out_file):
    """Write ``item['image']`` to *out_file* as an 8-bit image.

    The first axis is moved to the end, then values are scaled by 255 and
    clipped to [0, 255].
    """
    channels_last = np.transpose(item['image'], (1, 2, 0))
    io.imsave(out_file, np.clip(channels_last * 255, 0, 255).astype('uint8'))
24
+
25
def save_masked_img_for_sidebyside(item, out_file):
    """Write the image with its masked area filled by the mask value.

    Computes ``(1 - mask) * image + mask``, moves the first axis to the end,
    scales by 255, clips to [0, 255] and saves to *out_file* as 8-bit.
    """
    hole = item['mask']
    picture = item['image']

    blended = np.transpose((1-hole) * picture + hole, (1, 2, 0))
    io.imsave(out_file, np.clip(blended * 255, 0, 255).astype('uint8'))
34
+
35
def main(args):
    """Save a random sample of masked images for every hole-area bin.

    The fraction of pixels where the mask equals 1 is computed for each
    dataset item, items are bucketed into ``args.area_bins`` equal-width bins
    over [0, 1], and up to ``args.samples_n`` items per bin are rendered with
    ``save_masked_img_for_sidebyside`` into ``<args.outdir>/<lo>-<hi>/``.
    """
    dataset = InpaintingDataset(args.datadir, img_suffix='.png')

    area_bins = np.linspace(0, 1, args.area_bins + 1)
    area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)]

    # bin2i[b] collects the dataset indices whose hole fraction falls into bin b.
    bin2i = [[] for _ in range(args.area_bins)]

    for i, item in enumerate(tqdm.tqdm(dataset)):
        h, w = item['image'].shape[1:]
        hole_percent = (item['mask'] == 1).sum() / (h * w)
        # searchsorted returns the insertion point; shift to a bin index and
        # clamp so that fractions of exactly 0 or 1 land in the edge bins.
        bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, args.area_bins - 1)
        bin2i[bin_i].append(i)

    os.makedirs(args.outdir, exist_ok=True)

    for bin_i in range(args.area_bins):
        bindir = os.path.join(args.outdir, area_bin_titles[bin_i])
        os.makedirs(bindir, exist_ok=True)
        bin_idx = bin2i[bin_i]
        for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False):
            item = dataset[sample_i]
            path = os.path.join(bindir, os.path.basename(dataset.img_filenames[sample_i]))
            save_masked_img_for_sidebyside(item, path)
74
+
75
+
76
+ if __name__ == '__main__':
77
+ import argparse
78
+
79
+ aparser = argparse.ArgumentParser()
80
+ aparser.add_argument('--datadir', type=str,
81
+ help='Path to folder with images and masks (output of gen_mask_dataset.py)')
82
+ aparser.add_argument('--outdir', type=str, help='Where to put results')
83
+ aparser.add_argument('--samples-n', type=int, default=10,
84
+ help='Number of sample images with masks to copy for visualization for each area bin')
85
+ aparser.add_argument('--area-bins', type=int, default=10, help='How many area bins to have')
86
+
87
+ main(aparser.parse_args())
bin/side_by_side.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import random
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
9
+ from saicinpainting.evaluation.utils import load_yaml
10
+ from saicinpainting.training.visualizers.base import visualize_mask_and_images
11
+
12
+
13
def main(args):
    """Render side-by-side comparisons of several inpainting results.

    Builds one ``PrecomputedInpaintingResultsDataset`` per directory in
    ``args.predictdirs`` (all must have the same length), picks at most
    ``args.max_n`` random indices and, for each of them, writes a visualization
    of the source image followed by every prediction to ``args.outpath``.
    Items that fail are reported on stdout and skipped.
    """
    config = load_yaml(args.config)

    datasets = [PrecomputedInpaintingResultsDataset(args.datadir, cur_predictdir, **config.dataset_kwargs)
                for cur_predictdir in args.predictdirs]
    assert len({len(ds) for ds in datasets}) == 1
    len_first = len(datasets[0])

    indices = list(range(len_first))
    if len_first > args.max_n:
        indices = sorted(random.sample(indices, args.max_n))

    os.makedirs(args.outpath, exist_ok=True)

    filename2i = {}  # how many times each output base name has been seen

    keys = ['image'] + list(range(len(datasets)))
    for img_i in indices:
        try:
            mask_fname = os.path.basename(datasets[0].mask_filenames[img_i])
            if mask_fname in filename2i:
                filename2i[mask_fname] += 1
                idx = filename2i[mask_fname]
                # splitext (not split): the counter goes between stem and
                # extension, e.g. "a.png" -> "a_2.png".
                mask_fname_only, ext = os.path.splitext(mask_fname)
                mask_fname = f'{mask_fname_only}_{idx}{ext}'
            else:
                filename2i[mask_fname] = 1

            cur_vis_dict = datasets[0][img_i]
            for ds_i, ds in enumerate(datasets):
                cur_vis_dict[ds_i] = ds[img_i]['inpainted']

            vis_img = visualize_mask_and_images(cur_vis_dict, keys,
                                                last_without_mask=False,
                                                mask_only_first=True,
                                                black_mask=args.black)
            vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')

            out_fname = os.path.join(args.outpath, mask_fname)

            vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(out_fname, vis_img)
        except Exception as ex:
            # Best effort: one broken item must not abort the whole report.
            print(f'Could not process {img_i} due to {ex}')
59
+
60
+
61
+ if __name__ == '__main__':
62
+ import argparse
63
+
64
+ aparser = argparse.ArgumentParser()
65
+ aparser.add_argument('--max-n', type=int, default=100, help='Maximum number of images to print')
66
+ aparser.add_argument('--black', action='store_true', help='Whether to fill mask on GT with black')
67
+ aparser.add_argument('config', type=str, help='Path to evaluation config (e.g. configs/eval1.yaml)')
68
+ aparser.add_argument('outpath', type=str, help='Where to put results')
69
+ aparser.add_argument('datadir', type=str,
70
+ help='Path to folder with images and masks')
71
+ aparser.add_argument('predictdirs', type=str,
72
+ nargs='+',
73
+ help='Path to folders with predicts')
74
+
75
+
76
+ main(aparser.parse_args())
bin/split_tar.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+
4
+ import tqdm
5
+ import webdataset as wds
6
+
7
+
8
def main(args):
    """Copy every record of ``args.infile`` into shards named by ``args.outpattern``."""
    input_dataset = wds.Dataset(args.infile)
    output_dataset = wds.ShardWriter(args.outpattern)
    try:
        for rec in tqdm.tqdm(input_dataset):
            output_dataset.write(rec)
    finally:
        # Close the writer so the last shard is finished and flushed to disk.
        output_dataset.close()
13
+
14
+
15
+ if __name__ == '__main__':
16
+ import argparse
17
+
18
+ aparser = argparse.ArgumentParser()
19
+ aparser.add_argument('infile', type=str)
20
+ aparser.add_argument('outpattern', type=str)
21
+
22
+ main(aparser.parse_args())
bin/train.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import logging
4
+ import os
5
+ import sys
6
+ import traceback
7
+
8
+ os.environ['OMP_NUM_THREADS'] = '1'
9
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
10
+ os.environ['MKL_NUM_THREADS'] = '1'
11
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
12
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
13
+
14
+ import hydra
15
+ from omegaconf import OmegaConf
16
+ from pytorch_lightning import Trainer
17
+ from pytorch_lightning.callbacks import ModelCheckpoint
18
+ from pytorch_lightning.loggers import TensorBoardLogger
19
+ from pytorch_lightning.plugins import DDPPlugin
20
+
21
+ from saicinpainting.training.trainers import make_training_model
22
+ from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \
23
+ handle_deterministic_config
24
+
25
+ LOGGER = logging.getLogger(__name__)
26
+
27
+
28
@handle_ddp_subprocess()
@hydra.main(config_path='../configs/training', config_name='tiny_test.yaml')
def main(config: OmegaConf):
    """Train an inpainting model described by the Hydra *config*.

    Saves the resolved config to ``config.yaml`` and checkpoints to ``models/``
    in the current working directory, logs metrics to TensorBoard under
    ``config.location.tb_dir``, then fits the model with a PyTorch Lightning
    ``Trainer``. Exits with status 1 on any error other than a keyboard
    interrupt.
    """
    try:
        need_set_deterministic = handle_deterministic_config(config)

        register_debug_signal_handlers()  # kill -10 <pid> will result in traceback dumped into log

        is_in_ddp_subprocess = handle_ddp_parent_process()

        # Make the visualizer output directory absolute, rooted at the current working directory.
        config.visualizer.outdir = os.path.join(os.getcwd(), config.visualizer.outdir)
        # Only the parent process logs and saves the config.
        if not is_in_ddp_subprocess:
            LOGGER.info(OmegaConf.to_yaml(config))
            OmegaConf.save(config, os.path.join(os.getcwd(), 'config.yaml'))

        checkpoints_dir = os.path.join(os.getcwd(), 'models')
        os.makedirs(checkpoints_dir, exist_ok=True)

        # there is no need to suppress this logger in ddp, because it handles rank on its own
        metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd()))
        metrics_logger.log_hyperparams(config)

        training_model = make_training_model(config)

        # Convert to a plain container so the kwargs can be modified and unpacked below.
        trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True)
        if need_set_deterministic:
            trainer_kwargs['deterministic'] = True

        trainer = Trainer(
            # there is no need to suppress checkpointing in ddp, because it handles rank on its own
            callbacks=ModelCheckpoint(dirpath=checkpoints_dir, **config.trainer.checkpoint_kwargs),
            logger=metrics_logger,
            default_root_dir=os.getcwd(),
            **trainer_kwargs
        )
        trainer.fit(training_model)
    except KeyboardInterrupt:
        LOGGER.warning('Interrupted by user')
    except Exception as ex:
        LOGGER.critical(f'Training failed due to {ex}:\n{traceback.format_exc()}')
        sys.exit(1)
69
+
70
+
71
+ if __name__ == '__main__':
72
+ main()
canvas.png ADDED
conda_env.yml ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: lama
2
+ channels:
3
+ - defaults
4
+ - conda-forge
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - _openmp_mutex=4.5=1_gnu
8
+ - absl-py=0.13.0=py36h06a4308_0
9
+ - aiohttp=3.7.4.post0=py36h7f8727e_2
10
+ - antlr-python-runtime=4.8=py36h9f0ad1d_2
11
+ - async-timeout=3.0.1=py36h06a4308_0
12
+ - attrs=21.2.0=pyhd3eb1b0_0
13
+ - blas=1.0=mkl
14
+ - blinker=1.4=py36h06a4308_0
15
+ - brotlipy=0.7.0=py36h27cfd23_1003
16
+ - bzip2=1.0.8=h7b6447c_0
17
+ - c-ares=1.17.1=h27cfd23_0
18
+ - ca-certificates=2021.7.5=h06a4308_1
19
+ - cachetools=4.2.2=pyhd3eb1b0_0
20
+ - certifi=2021.5.30=py36h06a4308_0
21
+ - cffi=1.14.6=py36h400218f_0
22
+ - chardet=4.0.0=py36h06a4308_1003
23
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
24
+ - click=8.0.1=pyhd3eb1b0_0
25
+ - cloudpickle=2.0.0=pyhd3eb1b0_0
26
+ - coverage=5.5=py36h27cfd23_2
27
+ - cryptography=3.4.7=py36hd23ed53_0
28
+ - cudatoolkit=10.2.89=hfd86e86_1
29
+ - cycler=0.10.0=py36_0
30
+ - cython=0.29.24=py36h295c915_0
31
+ - cytoolz=0.11.0=py36h7b6447c_0
32
+ - dask-core=1.1.4=py36_1
33
+ - dataclasses=0.8=pyh4f3eec9_6
34
+ - dbus=1.13.18=hb2f20db_0
35
+ - decorator=5.0.9=pyhd3eb1b0_0
36
+ - easydict=1.9=py_0
37
+ - expat=2.4.1=h2531618_2
38
+ - ffmpeg=4.2.2=h20bf706_0
39
+ - fontconfig=2.13.1=h6c09931_0
40
+ - freetype=2.10.4=h5ab3b9f_0
41
+ - fsspec=2021.8.1=pyhd3eb1b0_0
42
+ - future=0.18.2=py36_1
43
+ - glib=2.69.1=h5202010_0
44
+ - gmp=6.2.1=h2531618_2
45
+ - gnutls=3.6.15=he1e5248_0
46
+ - google-auth=1.33.0=pyhd3eb1b0_0
47
+ - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
48
+ - grpcio=1.36.1=py36h2157cd5_1
49
+ - gst-plugins-base=1.14.0=h8213a91_2
50
+ - gstreamer=1.14.0=h28cd5cc_2
51
+ - hydra-core=1.1.0=pyhd8ed1ab_0
52
+ - icu=58.2=he6710b0_3
53
+ - idna=3.2=pyhd3eb1b0_0
54
+ - idna_ssl=1.1.0=py36h06a4308_0
55
+ - imageio=2.9.0=pyhd3eb1b0_0
56
+ - importlib-metadata=4.8.1=py36h06a4308_0
57
+ - importlib_resources=5.2.0=pyhd3eb1b0_1
58
+ - intel-openmp=2021.3.0=h06a4308_3350
59
+ - joblib=1.0.1=pyhd3eb1b0_0
60
+ - jpeg=9b=h024ee3a_2
61
+ - kiwisolver=1.3.1=py36h2531618_0
62
+ - lame=3.100=h7b6447c_0
63
+ - lcms2=2.12=h3be6417_0
64
+ - ld_impl_linux-64=2.35.1=h7274673_9
65
+ - libblas=3.9.0=11_linux64_mkl
66
+ - libcblas=3.9.0=11_linux64_mkl
67
+ - libffi=3.3=he6710b0_2
68
+ - libgcc-ng=9.3.0=h5101ec6_17
69
+ - libgfortran-ng=9.3.0=ha5ec8a7_17
70
+ - libgfortran5=9.3.0=ha5ec8a7_17
71
+ - libgomp=9.3.0=h5101ec6_17
72
+ - libidn2=2.3.2=h7f8727e_0
73
+ - liblapack=3.9.0=11_linux64_mkl
74
+ - libopus=1.3.1=h7b6447c_0
75
+ - libpng=1.6.37=hbc83047_0
76
+ - libprotobuf=3.17.2=h4ff587b_1
77
+ - libstdcxx-ng=9.3.0=hd4cf53a_17
78
+ - libtasn1=4.16.0=h27cfd23_0
79
+ - libtiff=4.2.0=h85742a9_0
80
+ - libunistring=0.9.10=h27cfd23_0
81
+ - libuuid=1.0.3=h1bed415_2
82
+ - libuv=1.40.0=h7b6447c_0
83
+ - libvpx=1.7.0=h439df22_0
84
+ - libwebp-base=1.2.0=h27cfd23_0
85
+ - libxcb=1.14=h7b6447c_0
86
+ - libxml2=2.9.12=h03d6c58_0
87
+ - lz4-c=1.9.3=h295c915_1
88
+ - markdown=3.3.4=py36h06a4308_0
89
+ - matplotlib=3.3.4=py36h06a4308_0
90
+ - matplotlib-base=3.3.4=py36h62a2d02_0
91
+ - mkl=2021.3.0=h06a4308_520
92
+ - multidict=5.1.0=py36h27cfd23_2
93
+ - ncurses=6.2=he6710b0_1
94
+ - nettle=3.7.3=hbbd107a_1
95
+ - networkx=2.2=py36_1
96
+ - ninja=1.10.2=hff7bd54_1
97
+ - numpy=1.19.5=py36hfc0c790_2
98
+ - oauthlib=3.1.1=pyhd3eb1b0_0
99
+ - olefile=0.46=py36_0
100
+ - omegaconf=2.1.1=py36h5fab9bb_0
101
+ - openh264=2.1.0=hd408876_0
102
+ - openjpeg=2.4.0=h3ad879b_0
103
+ - openssl=1.1.1l=h7f8727e_0
104
+ - packaging=21.0=pyhd3eb1b0_0
105
+ - pandas=1.1.5=py36h284efc9_0
106
+ - pcre=8.45=h295c915_0
107
+ - pillow=8.3.1=py36h2c7a002_0
108
+ - pip=21.0.1=py36h06a4308_0
109
+ - protobuf=3.17.2=py36h295c915_0
110
+ - pyasn1=0.4.8=pyhd3eb1b0_0
111
+ - pyasn1-modules=0.2.8=py_0
112
+ - pycparser=2.20=py_2
113
+ - pyjwt=2.1.0=py36h06a4308_0
114
+ - pyopenssl=20.0.1=pyhd3eb1b0_1
115
+ - pyparsing=2.4.7=pyhd3eb1b0_0
116
+ - pyqt=5.9.2=py36h05f1152_2
117
+ - pysocks=1.7.1=py36h06a4308_0
118
+ - python=3.6.13=h12debd9_1
119
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
120
+ - python_abi=3.6=2_cp36m
121
+ - pytz=2021.1=pyhd3eb1b0_0
122
+ - pywavelets=1.1.1=py36h7b6447c_2
123
+ - pyyaml=5.4.1=py36h27cfd23_1
124
+ - qt=5.9.7=h5867ecd_1
125
+ - readline=8.1=h27cfd23_0
126
+ - requests=2.26.0=pyhd3eb1b0_0
127
+ - requests-oauthlib=1.3.0=py_0
128
+ - rsa=4.7.2=pyhd3eb1b0_1
129
+ - scikit-image=0.17.2=py36h284efc9_4
130
+ - scikit-learn=0.24.2=py36ha9443f7_0
131
+ - scipy=1.5.3=py36h9e8f40b_0
132
+ - setuptools=58.0.4=py36h06a4308_0
133
+ - sip=4.19.8=py36hf484d3e_0
134
+ - six=1.16.0=pyhd3eb1b0_0
135
+ - sqlite=3.36.0=hc218d9a_0
136
+ - tabulate=0.8.9=py36h06a4308_0
137
+ - tensorboard=2.4.0=pyhc547734_0
138
+ - tensorboard-plugin-wit=1.6.0=py_0
139
+ - threadpoolctl=2.2.0=pyh0d69192_0
140
+ - tifffile=2020.10.1=py36hdd07704_2
141
+ - tk=8.6.11=h1ccaba5_0
142
+ - toolz=0.11.1=pyhd3eb1b0_0
143
+ - tqdm=4.62.2=pyhd3eb1b0_1
144
+ - typing-extensions=3.10.0.2=hd3eb1b0_0
145
+ - typing_extensions=3.10.0.2=pyh06a4308_0
146
+ - urllib3=1.26.6=pyhd3eb1b0_1
147
+ - werkzeug=2.0.1=pyhd3eb1b0_0
148
+ - wheel=0.37.0=pyhd3eb1b0_1
149
+ - x264=1!157.20191217=h7b6447c_0
150
+ - xz=5.2.5=h7b6447c_0
151
+ - yaml=0.2.5=h7b6447c_0
152
+ - yarl=1.6.3=py36h27cfd23_0
153
+ - zipp=3.5.0=pyhd3eb1b0_0
154
+ - zlib=1.2.11=h7b6447c_3
155
+ - zstd=1.4.9=haebb681_0
156
+ - pip:
157
+ - albumentations==0.5.2
158
+ - braceexpand==0.1.7
159
+ - imgaug==0.4.0
160
+ - kornia==0.5.0
161
+ - opencv-python==4.5.3.56
162
+ - opencv-python-headless==4.5.3.56
163
+ - shapely==1.7.1
164
+ - webdataset==0.1.76
165
+ - wldhx-yadisk-direct==0.0.6
configs/analyze_mask_errors.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ dataset_kwargs:
2
+ img_suffix: .jpg
3
+ inpainted_suffix: .jpg
4
+
5
+ take_global_top: 30
6
+ take_worst_best_top: 30
7
+ take_overlapping_top: 30
configs/data_gen/gen_segm_dataset1.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: segmentation
2
+
3
+ mask_generator_kwargs:
4
+ confidence_threshold: 0.5
5
+ max_object_area: 0.5
6
+ min_mask_area: 0.02
7
+ downsample_levels: 6
8
+ num_variants_per_mask: 5
9
+ rigidness_mode: 1
10
+ max_foreground_coverage: 0.3
11
+ max_foreground_intersection: 0.7
12
+ max_mask_intersection: 0.1
13
+ max_hidden_area: 0.1
14
+ max_scale_change: 0.25
15
+ horizontal_flip: True
16
+ max_vertical_shift: 0.2
17
+ position_shuffle: True
18
+
19
+ max_masks_per_image: 5
20
+
21
+ cropping:
22
+ out_min_size: 512
23
+ handle_small_mode: drop
24
+ out_square_crop: True
25
+ crop_min_overlap: 0.5
configs/data_gen/gen_segm_dataset3.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: segmentation
2
+
3
+ mask_generator_kwargs:
4
+ confidence_threshold: 0.5
5
+ max_object_area: 0.5
6
+ min_mask_area: 0.07
7
+ downsample_levels: 6
8
+ num_variants_per_mask: 3
9
+ rigidness_mode: 1
10
+ max_foreground_coverage: 0.4
11
+ max_foreground_intersection: 0.8
12
+ max_mask_intersection: 0.2
13
+ max_hidden_area: 0.1
14
+ max_scale_change: 0.25
15
+ horizontal_flip: True
16
+ max_vertical_shift: 0.3
17
+ position_shuffle: True
18
+
19
+ max_masks_per_image: 3
20
+
21
+ cropping:
22
+ out_min_size: 512
23
+ handle_small_mode: drop
24
+ out_square_crop: True
25
+ crop_min_overlap: 0.5
configs/data_gen/random_medium_256.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 5
8
+ max_width: 50
9
+ max_angle: 4
10
+ max_len: 100
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 0
15
+ bbox_min_size: 10
16
+ bbox_max_size: 50
17
+ max_times: 5
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+ variants_n: 5
24
+
25
+ max_masks_per_image: 1
26
+
27
+ cropping:
28
+ out_min_size: 256
29
+ handle_small_mode: upscale
30
+ out_square_crop: True
31
+ crop_min_overlap: 1
32
+
33
+ max_tamper_area: 0.5
configs/data_gen/random_medium_512.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 10
8
+ max_width: 100
9
+ max_angle: 4
10
+ max_len: 200
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 0
15
+ bbox_min_size: 30
16
+ bbox_max_size: 150
17
+ max_times: 5
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+ variants_n: 5
24
+
25
+ max_masks_per_image: 1
26
+
27
+ cropping:
28
+ out_min_size: 512
29
+ handle_small_mode: upscale
30
+ out_square_crop: True
31
+ crop_min_overlap: 1
32
+
33
+ max_tamper_area: 0.5
configs/data_gen/random_thick_256.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 1
7
+ max_times: 5
8
+ max_width: 100
9
+ max_angle: 4
10
+ max_len: 200
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 10
15
+ bbox_min_size: 30
16
+ bbox_max_size: 150
17
+ max_times: 3
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+ variants_n: 5
24
+
25
+ max_masks_per_image: 1
26
+
27
+ cropping:
28
+ out_min_size: 256
29
+ handle_small_mode: upscale
30
+ out_square_crop: True
31
+ crop_min_overlap: 1
32
+
33
+ max_tamper_area: 0.5
configs/data_gen/random_thick_512.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 1
7
+ max_times: 5
8
+ max_width: 250
9
+ max_angle: 4
10
+ max_len: 450
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 10
15
+ bbox_min_size: 30
16
+ bbox_max_size: 300
17
+ max_times: 4
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+ variants_n: 5
24
+
25
+ max_masks_per_image: 1
26
+
27
+ cropping:
28
+ out_min_size: 512
29
+ handle_small_mode: upscale
30
+ out_square_crop: True
31
+ crop_min_overlap: 1
32
+
33
+ max_tamper_area: 0.5
configs/data_gen/random_thin_256.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 50
8
+ max_width: 10
9
+ max_angle: 4
10
+ max_len: 40
11
+ box_proba: 0
12
+ segm_proba: 0
13
+ squares_proba: 0
14
+
15
+ variants_n: 5
16
+
17
+ max_masks_per_image: 1
18
+
19
+ cropping:
20
+ out_min_size: 256
21
+ handle_small_mode: upscale
22
+ out_square_crop: True
23
+ crop_min_overlap: 1
24
+
25
+ max_tamper_area: 0.5
configs/data_gen/random_thin_512.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 70
8
+ max_width: 20
9
+ max_angle: 4
10
+ max_len: 100
11
+ box_proba: 0
12
+ segm_proba: 0
13
+ squares_proba: 0
14
+
15
+ variants_n: 5
16
+
17
+ max_masks_per_image: 1
18
+
19
+ cropping:
20
+ out_min_size: 512
21
+ handle_small_mode: upscale
22
+ out_square_crop: True
23
+ crop_min_overlap: 1
24
+
25
+ max_tamper_area: 0.5
configs/data_gen/segm_256.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: segmentation
2
+
3
+ mask_generator_kwargs:
4
+ confidence_threshold: 0.5
5
+ max_object_area: 0.5
6
+ min_mask_area: 0.05
7
+ downsample_levels: 6
8
+ num_variants_per_mask: 3
9
+ rigidness_mode: 1
10
+ max_foreground_coverage: 1 # turn off filtering by overlap
11
+ max_foreground_intersection: 1 # turn off filtering by overlap
12
+ max_mask_intersection: 0.2 # the lower this value, the higher the diversity
13
+ max_hidden_area: 0.5
14
+ max_scale_change: 0.25
15
+ horizontal_flip: True
16
+ max_vertical_shift: 0.3
17
+ position_shuffle: True
18
+
19
+ max_masks_per_image: 1
20
+
21
+ cropping:
22
+ out_min_size: 256
23
+ handle_small_mode: upscale
24
+ out_square_crop: True
25
+ crop_min_overlap: 1
26
+
27
+ max_tamper_area: 0.5
configs/data_gen/segm_512.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: segmentation
2
+
3
+ mask_generator_kwargs:
4
+ confidence_threshold: 0.5
5
+ max_object_area: 0.5
6
+ min_mask_area: 0.05
7
+ downsample_levels: 6
8
+ num_variants_per_mask: 3
9
+ rigidness_mode: 1
10
+ max_foreground_coverage: 1 # turn off filtering by overlap
11
+ max_foreground_intersection: 1 # turn off filtering by overlap
12
+ max_mask_intersection: 0.2 # the lower this value, the higher the diversity
13
+ max_hidden_area: 0.5
14
+ max_scale_change: 0.25
15
+ horizontal_flip: True
16
+ max_vertical_shift: 0.3
17
+ position_shuffle: True
18
+
19
+ max_masks_per_image: 1
20
+
21
+ cropping:
22
+ out_min_size: 512
23
+ handle_small_mode: upscale
24
+ out_square_crop: True
25
+ crop_min_overlap: 1
26
+
27
+ max_tamper_area: 0.5
configs/data_gen/sr_256.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 0
5
+ box_proba: 0
6
+ segm_proba: 0
7
+ squares_proba: 0
8
+ superres_proba: 1
9
+ superres_kwargs:
10
+ min_step: 2
11
+ max_step: 4
12
+ min_width: 1
13
+ max_width: 3
14
+
15
+ variants_n: 5
16
+
17
+ max_masks_per_image: 1
18
+
19
+ cropping:
20
+ out_min_size: 256
21
+ handle_small_mode: upscale
22
+ out_square_crop: True
23
+ crop_min_overlap: 1
24
+
25
+ max_tamper_area: 1
configs/data_gen/whydra/location/mml-ws01-celeba-hq.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ root_dir: /media/inpainting/CelebA-HQ
4
+ out_dir: /media/inpainting/paper_data/CelebA-HQ_val_test
5
+ extension: jpg
configs/data_gen/whydra/location/mml-ws01-ffhq.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ root_dir: /media/inpainting/FFHQ/
4
+ out_dir: /media/inpainting/paper_data/FFHQ_val
5
+ extension: png
configs/data_gen/whydra/location/mml-ws01-paris.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ root_dir: /media/inpainting/Paris_StreetView_Dataset
4
+ out_dir: /media/inpainting/paper_data/Paris_StreetView_Dataset_val
5
+ extension: png
configs/data_gen/whydra/location/mml7-places.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ root_dir: /data/inpainting/Places365
4
+ out_dir: /data/inpainting/paper_data/Places365_val_test
5
+ extension: jpg