Spaces:
Build error
Build error
AK391
commited on
Commit
·
7788a23
1
Parent(s):
5f7f727
example files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +42 -0
- bin/analyze_errors.py +316 -0
- bin/blur_predicts.py +57 -0
- bin/calc_dataset_stats.py +88 -0
- bin/debug/analyze_overlapping_masks.sh +31 -0
- bin/evaluate_predicts.py +79 -0
- bin/evaluator_example.py +76 -0
- bin/extract_masks.py +63 -0
- bin/filter_sharded_dataset.py +69 -0
- bin/gen_debug_mask_dataset.py +61 -0
- bin/gen_mask_dataset.py +130 -0
- bin/gen_mask_dataset_hydra.py +124 -0
- bin/gen_outpainting_dataset.py +88 -0
- bin/make_checkpoint.py +79 -0
- bin/mask_example.py +14 -0
- bin/paper_runfiles/blur_tests.sh +37 -0
- bin/paper_runfiles/env.sh +8 -0
- bin/paper_runfiles/find_best_checkpoint.py +54 -0
- bin/paper_runfiles/generate_test_celeba-hq.sh +17 -0
- bin/paper_runfiles/generate_test_ffhq.sh +17 -0
- bin/paper_runfiles/generate_test_paris.sh +17 -0
- bin/paper_runfiles/generate_test_paris_256.sh +17 -0
- bin/paper_runfiles/generate_val_test.sh +28 -0
- bin/paper_runfiles/predict_inner_features.sh +20 -0
- bin/paper_runfiles/update_test_data_stats.sh +30 -0
- bin/predict.py +89 -0
- bin/predict_inner_features.py +119 -0
- bin/report_from_tb.py +83 -0
- bin/sample_from_dataset.py +87 -0
- bin/side_by_side.py +76 -0
- bin/split_tar.py +22 -0
- bin/train.py +72 -0
- canvas.png +0 -0
- conda_env.yml +165 -0
- configs/analyze_mask_errors.yaml +7 -0
- configs/data_gen/gen_segm_dataset1.yaml +25 -0
- configs/data_gen/gen_segm_dataset3.yaml +25 -0
- configs/data_gen/random_medium_256.yaml +33 -0
- configs/data_gen/random_medium_512.yaml +33 -0
- configs/data_gen/random_thick_256.yaml +33 -0
- configs/data_gen/random_thick_512.yaml +33 -0
- configs/data_gen/random_thin_256.yaml +25 -0
- configs/data_gen/random_thin_512.yaml +25 -0
- configs/data_gen/segm_256.yaml +27 -0
- configs/data_gen/segm_512.yaml +27 -0
- configs/data_gen/sr_256.yaml +25 -0
- configs/data_gen/whydra/location/mml-ws01-celeba-hq.yaml +5 -0
- configs/data_gen/whydra/location/mml-ws01-ffhq.yaml +5 -0
- configs/data_gen/whydra/location/mml-ws01-paris.yaml +5 -0
- configs/data_gen/whydra/location/mml7-places.yaml +5 -0
app.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.system("gdown https://drive.google.com/uc?id=1-95IOJ-2y9BtmABiffIwndPqNZD_gLnV")
|
3 |
+
os.system("unzip big-lama.zip")
|
4 |
+
import cv2
|
5 |
+
import paddlehub as hub
|
6 |
+
import gradio as gr
|
7 |
+
import torch
|
8 |
+
from PIL import Image, ImageOps
|
9 |
+
import numpy as np
|
10 |
+
os.mkdir("data")
|
11 |
+
os.mkdir("dataout")
|
12 |
+
model = hub.Module(name='U2Net')
|
13 |
+
def infer(img,mask,option):
|
14 |
+
img = ImageOps.contain(img, (700,700))
|
15 |
+
width, height = img.size
|
16 |
+
img.save("./data/data.png")
|
17 |
+
if option == "automatic (U2net)":
|
18 |
+
result = model.Segmentation(
|
19 |
+
images=[cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)],
|
20 |
+
paths=None,
|
21 |
+
batch_size=1,
|
22 |
+
input_size=320,
|
23 |
+
output_dir='output',
|
24 |
+
visualization=True)
|
25 |
+
im = Image.fromarray(result[0]['mask'])
|
26 |
+
else:
|
27 |
+
mask = mask.resize((width,height))
|
28 |
+
im = mask
|
29 |
+
im.save("./data/data_mask.png")
|
30 |
+
os.system('python predict.py model.path=/home/user/app/big-lama/ indir=/home/user/app/data/ outdir=/home/user/app/dataout/ device=cpu')
|
31 |
+
return "./dataout/data_mask.png",im
|
32 |
+
|
33 |
+
inputs = [gr.inputs.Image(type='pil', label="Original Image"),gr.inputs.Image(type='pil',source="canvas", label="Mask",invert_colors=True),gr.inputs.Radio(choices=["automatic (U2net)","manual"], type="value", default="manual", label="Masking option")]
|
34 |
+
outputs = [gr.outputs.Image(type="file",label="output"),gr.outputs.Image(type="pil",label="Mask")]
|
35 |
+
title = "LaMa Image Inpainting"
|
36 |
+
description = "Gradio demo for LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Masks are generated by U^2net"
|
37 |
+
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2109.07161' target='_blank'>Resolution-robust Large Mask Inpainting with Fourier Convolutions</a> | <a href='https://github.com/saic-mdal/lama' target='_blank'>Github Repo</a></p>"
|
38 |
+
examples = [
|
39 |
+
['person512.png',"canvas.png","automatic (U2net)"],
|
40 |
+
['person512.png',"maskexam.png","manual"]
|
41 |
+
]
|
42 |
+
gr.Interface(infer, inputs, outputs, title=title, description=description, article=article, examples=examples).launch(enable_queue=True,cache_examples=True)
|
bin/analyze_errors.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import sklearn
|
5 |
+
import torch
|
6 |
+
import os
|
7 |
+
import pickle
|
8 |
+
import pandas as pd
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from joblib import Parallel, delayed
|
11 |
+
|
12 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset, load_image
|
13 |
+
from saicinpainting.evaluation.losses.fid.inception import InceptionV3
|
14 |
+
from saicinpainting.evaluation.utils import load_yaml
|
15 |
+
from saicinpainting.training.visualizers.base import visualize_mask_and_images
|
16 |
+
|
17 |
+
|
18 |
+
def draw_score(img, score):
|
19 |
+
img = np.transpose(img, (1, 2, 0))
|
20 |
+
cv2.putText(img, f'{score:.2f}',
|
21 |
+
(40, 40),
|
22 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
23 |
+
1,
|
24 |
+
(0, 1, 0),
|
25 |
+
thickness=3)
|
26 |
+
img = np.transpose(img, (2, 0, 1))
|
27 |
+
return img
|
28 |
+
|
29 |
+
|
30 |
+
def save_global_samples(global_mask_fnames, mask2real_fname, mask2fake_fname, out_dir, real_scores_by_fname, fake_scores_by_fname):
|
31 |
+
for cur_mask_fname in global_mask_fnames:
|
32 |
+
cur_real_fname = mask2real_fname[cur_mask_fname]
|
33 |
+
orig_img = load_image(cur_real_fname, mode='RGB')
|
34 |
+
fake_img = load_image(mask2fake_fname[cur_mask_fname], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
|
35 |
+
mask = load_image(cur_mask_fname, mode='L')[None, ...]
|
36 |
+
|
37 |
+
draw_score(orig_img, real_scores_by_fname.loc[cur_real_fname, 'real_score'])
|
38 |
+
draw_score(fake_img, fake_scores_by_fname.loc[cur_mask_fname, 'fake_score'])
|
39 |
+
|
40 |
+
cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=mask, fake=fake_img),
|
41 |
+
keys=['image', 'fake'],
|
42 |
+
last_without_mask=True)
|
43 |
+
cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8')
|
44 |
+
cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR)
|
45 |
+
cv2.imwrite(os.path.join(out_dir, os.path.splitext(os.path.basename(cur_mask_fname))[0] + '.jpg'),
|
46 |
+
cur_grid)
|
47 |
+
|
48 |
+
|
49 |
+
def save_samples_by_real(worst_best_by_real, mask2fake_fname, fake_info, out_dir):
|
50 |
+
for real_fname in worst_best_by_real.index:
|
51 |
+
worst_mask_path = worst_best_by_real.loc[real_fname, 'worst']
|
52 |
+
best_mask_path = worst_best_by_real.loc[real_fname, 'best']
|
53 |
+
orig_img = load_image(real_fname, mode='RGB')
|
54 |
+
worst_mask_img = load_image(worst_mask_path, mode='L')[None, ...]
|
55 |
+
worst_fake_img = load_image(mask2fake_fname[worst_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
|
56 |
+
best_mask_img = load_image(best_mask_path, mode='L')[None, ...]
|
57 |
+
best_fake_img = load_image(mask2fake_fname[best_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
|
58 |
+
|
59 |
+
draw_score(orig_img, worst_best_by_real.loc[real_fname, 'real_score'])
|
60 |
+
draw_score(worst_fake_img, worst_best_by_real.loc[real_fname, 'worst_score'])
|
61 |
+
draw_score(best_fake_img, worst_best_by_real.loc[real_fname, 'best_score'])
|
62 |
+
|
63 |
+
cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=np.zeros_like(worst_mask_img),
|
64 |
+
worst_mask=worst_mask_img, worst_img=worst_fake_img,
|
65 |
+
best_mask=best_mask_img, best_img=best_fake_img),
|
66 |
+
keys=['image', 'worst_mask', 'worst_img', 'best_mask', 'best_img'],
|
67 |
+
rescale_keys=['worst_mask', 'best_mask'],
|
68 |
+
last_without_mask=True)
|
69 |
+
cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8')
|
70 |
+
cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR)
|
71 |
+
cv2.imwrite(os.path.join(out_dir,
|
72 |
+
os.path.splitext(os.path.basename(real_fname))[0] + '.jpg'),
|
73 |
+
cur_grid)
|
74 |
+
|
75 |
+
fig, (ax1, ax2) = plt.subplots(1, 2)
|
76 |
+
cur_stat = fake_info[fake_info['real_fname'] == real_fname]
|
77 |
+
cur_stat['fake_score'].hist(ax=ax1)
|
78 |
+
cur_stat['real_score'].hist(ax=ax2)
|
79 |
+
fig.tight_layout()
|
80 |
+
fig.savefig(os.path.join(out_dir,
|
81 |
+
os.path.splitext(os.path.basename(real_fname))[0] + '_scores.png'))
|
82 |
+
plt.close(fig)
|
83 |
+
|
84 |
+
|
85 |
+
def extract_overlapping_masks(mask_fnames, cur_i, fake_scores_table, max_overlaps_n=2):
|
86 |
+
result_pairs = []
|
87 |
+
result_scores = []
|
88 |
+
mask_fname_a = mask_fnames[cur_i]
|
89 |
+
mask_a = load_image(mask_fname_a, mode='L')[None, ...] > 0.5
|
90 |
+
cur_score_a = fake_scores_table.loc[mask_fname_a, 'fake_score']
|
91 |
+
for mask_fname_b in mask_fnames[cur_i + 1:]:
|
92 |
+
mask_b = load_image(mask_fname_b, mode='L')[None, ...] > 0.5
|
93 |
+
if not np.any(mask_a & mask_b):
|
94 |
+
continue
|
95 |
+
cur_score_b = fake_scores_table.loc[mask_fname_b, 'fake_score']
|
96 |
+
result_pairs.append((mask_fname_a, mask_fname_b))
|
97 |
+
result_scores.append(cur_score_b - cur_score_a)
|
98 |
+
if len(result_pairs) >= max_overlaps_n:
|
99 |
+
break
|
100 |
+
return result_pairs, result_scores
|
101 |
+
|
102 |
+
|
103 |
+
def main(args):
|
104 |
+
config = load_yaml(args.config)
|
105 |
+
|
106 |
+
latents_dir = os.path.join(args.outpath, 'latents')
|
107 |
+
os.makedirs(latents_dir, exist_ok=True)
|
108 |
+
global_worst_dir = os.path.join(args.outpath, 'global_worst')
|
109 |
+
os.makedirs(global_worst_dir, exist_ok=True)
|
110 |
+
global_best_dir = os.path.join(args.outpath, 'global_best')
|
111 |
+
os.makedirs(global_best_dir, exist_ok=True)
|
112 |
+
worst_best_by_best_worst_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'best_worst_score_diff_max')
|
113 |
+
os.makedirs(worst_best_by_best_worst_score_diff_max_dir, exist_ok=True)
|
114 |
+
worst_best_by_best_worst_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'best_worst_score_diff_min')
|
115 |
+
os.makedirs(worst_best_by_best_worst_score_diff_min_dir, exist_ok=True)
|
116 |
+
worst_best_by_real_best_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_best_score_diff_max')
|
117 |
+
os.makedirs(worst_best_by_real_best_score_diff_max_dir, exist_ok=True)
|
118 |
+
worst_best_by_real_best_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_best_score_diff_min')
|
119 |
+
os.makedirs(worst_best_by_real_best_score_diff_min_dir, exist_ok=True)
|
120 |
+
worst_best_by_real_worst_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_worst_score_diff_max')
|
121 |
+
os.makedirs(worst_best_by_real_worst_score_diff_max_dir, exist_ok=True)
|
122 |
+
worst_best_by_real_worst_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_worst_score_diff_min')
|
123 |
+
os.makedirs(worst_best_by_real_worst_score_diff_min_dir, exist_ok=True)
|
124 |
+
|
125 |
+
if not args.only_report:
|
126 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
|
127 |
+
inception_model = InceptionV3([block_idx]).eval().cuda()
|
128 |
+
|
129 |
+
dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
|
130 |
+
|
131 |
+
real2vector_cache = {}
|
132 |
+
|
133 |
+
real_features = []
|
134 |
+
fake_features = []
|
135 |
+
|
136 |
+
orig_fnames = []
|
137 |
+
mask_fnames = []
|
138 |
+
mask2real_fname = {}
|
139 |
+
mask2fake_fname = {}
|
140 |
+
|
141 |
+
for batch_i, batch in enumerate(dataset):
|
142 |
+
orig_img_fname = dataset.img_filenames[batch_i]
|
143 |
+
mask_fname = dataset.mask_filenames[batch_i]
|
144 |
+
fake_fname = dataset.pred_filenames[batch_i]
|
145 |
+
mask2real_fname[mask_fname] = orig_img_fname
|
146 |
+
mask2fake_fname[mask_fname] = fake_fname
|
147 |
+
|
148 |
+
cur_real_vector = real2vector_cache.get(orig_img_fname, None)
|
149 |
+
if cur_real_vector is None:
|
150 |
+
with torch.no_grad():
|
151 |
+
in_img = torch.from_numpy(batch['image'][None, ...]).cuda()
|
152 |
+
cur_real_vector = inception_model(in_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()
|
153 |
+
real2vector_cache[orig_img_fname] = cur_real_vector
|
154 |
+
|
155 |
+
pred_img = torch.from_numpy(batch['inpainted'][None, ...]).cuda()
|
156 |
+
cur_fake_vector = inception_model(pred_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()
|
157 |
+
|
158 |
+
real_features.append(cur_real_vector)
|
159 |
+
fake_features.append(cur_fake_vector)
|
160 |
+
|
161 |
+
orig_fnames.append(orig_img_fname)
|
162 |
+
mask_fnames.append(mask_fname)
|
163 |
+
|
164 |
+
ids_features = np.concatenate(real_features + fake_features, axis=0)
|
165 |
+
ids_labels = np.array(([1] * len(real_features)) + ([0] * len(fake_features)))
|
166 |
+
|
167 |
+
with open(os.path.join(latents_dir, 'featues.pkl'), 'wb') as f:
|
168 |
+
pickle.dump(ids_features, f, protocol=3)
|
169 |
+
with open(os.path.join(latents_dir, 'labels.pkl'), 'wb') as f:
|
170 |
+
pickle.dump(ids_labels, f, protocol=3)
|
171 |
+
with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'wb') as f:
|
172 |
+
pickle.dump(orig_fnames, f, protocol=3)
|
173 |
+
with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'wb') as f:
|
174 |
+
pickle.dump(mask_fnames, f, protocol=3)
|
175 |
+
with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'wb') as f:
|
176 |
+
pickle.dump(mask2real_fname, f, protocol=3)
|
177 |
+
with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'wb') as f:
|
178 |
+
pickle.dump(mask2fake_fname, f, protocol=3)
|
179 |
+
|
180 |
+
svm = sklearn.svm.LinearSVC(dual=False)
|
181 |
+
svm.fit(ids_features, ids_labels)
|
182 |
+
|
183 |
+
pred_scores = svm.decision_function(ids_features)
|
184 |
+
real_scores = pred_scores[:len(real_features)]
|
185 |
+
fake_scores = pred_scores[len(real_features):]
|
186 |
+
|
187 |
+
with open(os.path.join(latents_dir, 'pred_scores.pkl'), 'wb') as f:
|
188 |
+
pickle.dump(pred_scores, f, protocol=3)
|
189 |
+
with open(os.path.join(latents_dir, 'real_scores.pkl'), 'wb') as f:
|
190 |
+
pickle.dump(real_scores, f, protocol=3)
|
191 |
+
with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'wb') as f:
|
192 |
+
pickle.dump(fake_scores, f, protocol=3)
|
193 |
+
else:
|
194 |
+
with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'rb') as f:
|
195 |
+
orig_fnames = pickle.load(f)
|
196 |
+
with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'rb') as f:
|
197 |
+
mask_fnames = pickle.load(f)
|
198 |
+
with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'rb') as f:
|
199 |
+
mask2real_fname = pickle.load(f)
|
200 |
+
with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'rb') as f:
|
201 |
+
mask2fake_fname = pickle.load(f)
|
202 |
+
with open(os.path.join(latents_dir, 'real_scores.pkl'), 'rb') as f:
|
203 |
+
real_scores = pickle.load(f)
|
204 |
+
with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'rb') as f:
|
205 |
+
fake_scores = pickle.load(f)
|
206 |
+
|
207 |
+
real_info = pd.DataFrame(data=[dict(real_fname=fname,
|
208 |
+
real_score=score)
|
209 |
+
for fname, score
|
210 |
+
in zip(orig_fnames, real_scores)])
|
211 |
+
real_info.set_index('real_fname', drop=True, inplace=True)
|
212 |
+
|
213 |
+
fake_info = pd.DataFrame(data=[dict(mask_fname=fname,
|
214 |
+
fake_fname=mask2fake_fname[fname],
|
215 |
+
real_fname=mask2real_fname[fname],
|
216 |
+
fake_score=score)
|
217 |
+
for fname, score
|
218 |
+
in zip(mask_fnames, fake_scores)])
|
219 |
+
fake_info = fake_info.join(real_info, on='real_fname', how='left')
|
220 |
+
fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)
|
221 |
+
|
222 |
+
fake_stats_by_real = fake_info.groupby('real_fname')['fake_score'].describe()[['mean', 'std']].rename(
|
223 |
+
{'mean': 'mean_fake_by_real', 'std': 'std_fake_by_real'}, axis=1)
|
224 |
+
fake_info = fake_info.join(fake_stats_by_real, on='real_fname', rsuffix='stat_by_real')
|
225 |
+
fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)
|
226 |
+
fake_info.to_csv(os.path.join(latents_dir, 'join_scores_table.csv'), sep='\t', index=False)
|
227 |
+
|
228 |
+
fake_scores_table = fake_info.set_index('mask_fname')['fake_score'].to_frame()
|
229 |
+
real_scores_table = fake_info.set_index('real_fname')['real_score'].drop_duplicates().to_frame()
|
230 |
+
|
231 |
+
fig, (ax1, ax2) = plt.subplots(1, 2)
|
232 |
+
ax1.hist(fake_scores)
|
233 |
+
ax2.hist(real_scores)
|
234 |
+
fig.tight_layout()
|
235 |
+
fig.savefig(os.path.join(args.outpath, 'global_scores_hist.png'))
|
236 |
+
plt.close(fig)
|
237 |
+
|
238 |
+
global_worst_masks = fake_info.sort_values('fake_score', ascending=True)['mask_fname'].iloc[:config.take_global_top].to_list()
|
239 |
+
global_best_masks = fake_info.sort_values('fake_score', ascending=False)['mask_fname'].iloc[:config.take_global_top].to_list()
|
240 |
+
save_global_samples(global_worst_masks, mask2real_fname, mask2fake_fname, global_worst_dir, real_scores_table, fake_scores_table)
|
241 |
+
save_global_samples(global_best_masks, mask2real_fname, mask2fake_fname, global_best_dir, real_scores_table, fake_scores_table)
|
242 |
+
|
243 |
+
# grouped by real
|
244 |
+
worst_samples_by_real = fake_info.groupby('real_fname').apply(
|
245 |
+
lambda d: d.set_index('mask_fname')['fake_score'].idxmin()).to_frame().rename({0: 'worst'}, axis=1)
|
246 |
+
best_samples_by_real = fake_info.groupby('real_fname').apply(
|
247 |
+
lambda d: d.set_index('mask_fname')['fake_score'].idxmax()).to_frame().rename({0: 'best'}, axis=1)
|
248 |
+
worst_best_by_real = pd.concat([worst_samples_by_real, best_samples_by_real], axis=1)
|
249 |
+
|
250 |
+
worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'worst_score'}, axis=1),
|
251 |
+
on='worst')
|
252 |
+
worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'best_score'}, axis=1),
|
253 |
+
on='best')
|
254 |
+
worst_best_by_real = worst_best_by_real.join(real_scores_table)
|
255 |
+
|
256 |
+
worst_best_by_real['best_worst_score_diff'] = worst_best_by_real['best_score'] - worst_best_by_real['worst_score']
|
257 |
+
worst_best_by_real['real_best_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['best_score']
|
258 |
+
worst_best_by_real['real_worst_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['worst_score']
|
259 |
+
|
260 |
+
worst_best_by_best_worst_score_diff_min = worst_best_by_real.sort_values('best_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
|
261 |
+
worst_best_by_best_worst_score_diff_max = worst_best_by_real.sort_values('best_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
|
262 |
+
save_samples_by_real(worst_best_by_best_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_min_dir)
|
263 |
+
save_samples_by_real(worst_best_by_best_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_max_dir)
|
264 |
+
|
265 |
+
worst_best_by_real_best_score_diff_min = worst_best_by_real.sort_values('real_best_score_diff', ascending=True).iloc[:config.take_worst_best_top]
|
266 |
+
worst_best_by_real_best_score_diff_max = worst_best_by_real.sort_values('real_best_score_diff', ascending=False).iloc[:config.take_worst_best_top]
|
267 |
+
save_samples_by_real(worst_best_by_real_best_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_min_dir)
|
268 |
+
save_samples_by_real(worst_best_by_real_best_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_max_dir)
|
269 |
+
|
270 |
+
worst_best_by_real_worst_score_diff_min = worst_best_by_real.sort_values('real_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
|
271 |
+
worst_best_by_real_worst_score_diff_max = worst_best_by_real.sort_values('real_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
|
272 |
+
save_samples_by_real(worst_best_by_real_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_min_dir)
|
273 |
+
save_samples_by_real(worst_best_by_real_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_max_dir)
|
274 |
+
|
275 |
+
# analyze what change of mask causes bigger change of score
|
276 |
+
overlapping_mask_fname_pairs = []
|
277 |
+
overlapping_mask_fname_score_diffs = []
|
278 |
+
for cur_real_fname in orig_fnames:
|
279 |
+
cur_fakes_info = fake_info[fake_info['real_fname'] == cur_real_fname]
|
280 |
+
cur_mask_fnames = sorted(cur_fakes_info['mask_fname'].unique())
|
281 |
+
|
282 |
+
cur_mask_pairs_and_scores = Parallel(args.n_jobs)(
|
283 |
+
delayed(extract_overlapping_masks)(cur_mask_fnames, i, fake_scores_table)
|
284 |
+
for i in range(len(cur_mask_fnames) - 1)
|
285 |
+
)
|
286 |
+
for cur_pairs, cur_scores in cur_mask_pairs_and_scores:
|
287 |
+
overlapping_mask_fname_pairs.extend(cur_pairs)
|
288 |
+
overlapping_mask_fname_score_diffs.extend(cur_scores)
|
289 |
+
|
290 |
+
overlapping_mask_fname_pairs = np.asarray(overlapping_mask_fname_pairs)
|
291 |
+
overlapping_mask_fname_score_diffs = np.asarray(overlapping_mask_fname_score_diffs)
|
292 |
+
overlapping_sort_idx = np.argsort(overlapping_mask_fname_score_diffs)
|
293 |
+
overlapping_mask_fname_pairs = overlapping_mask_fname_pairs[overlapping_sort_idx]
|
294 |
+
overlapping_mask_fname_score_diffs = overlapping_mask_fname_score_diffs[overlapping_sort_idx]
|
295 |
+
|
296 |
+
|
297 |
+
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
if __name__ == '__main__':
|
302 |
+
import argparse
|
303 |
+
|
304 |
+
aparser = argparse.ArgumentParser()
|
305 |
+
aparser.add_argument('config', type=str, help='Path to config for dataset generation')
|
306 |
+
aparser.add_argument('datadir', type=str,
|
307 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
308 |
+
aparser.add_argument('predictdir', type=str,
|
309 |
+
help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
|
310 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
311 |
+
aparser.add_argument('--only-report', action='store_true',
|
312 |
+
help='Whether to skip prediction and feature extraction, '
|
313 |
+
'load all the possible latents and proceed with report only')
|
314 |
+
aparser.add_argument('--n-jobs', type=int, default=8, help='how many processes to use for pair mask mining')
|
315 |
+
|
316 |
+
main(aparser.parse_args())
|
bin/blur_predicts.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import tqdm
|
8 |
+
|
9 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
|
10 |
+
from saicinpainting.evaluation.utils import load_yaml
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
config = load_yaml(args.config)
|
15 |
+
|
16 |
+
if not args.predictdir.endswith('/'):
|
17 |
+
args.predictdir += '/'
|
18 |
+
|
19 |
+
dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
|
20 |
+
|
21 |
+
os.makedirs(os.path.dirname(args.outpath), exist_ok=True)
|
22 |
+
|
23 |
+
for img_i in tqdm.trange(len(dataset)):
|
24 |
+
pred_fname = dataset.pred_filenames[img_i]
|
25 |
+
cur_out_fname = os.path.join(args.outpath, pred_fname[len(args.predictdir):])
|
26 |
+
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
|
27 |
+
|
28 |
+
sample = dataset[img_i]
|
29 |
+
img = sample['image']
|
30 |
+
mask = sample['mask']
|
31 |
+
inpainted = sample['inpainted']
|
32 |
+
|
33 |
+
inpainted_blurred = cv2.GaussianBlur(np.transpose(inpainted, (1, 2, 0)),
|
34 |
+
ksize=(args.k, args.k),
|
35 |
+
sigmaX=args.s, sigmaY=args.s,
|
36 |
+
borderType=cv2.BORDER_REFLECT)
|
37 |
+
|
38 |
+
cur_res = (1 - mask) * np.transpose(img, (1, 2, 0)) + mask * inpainted_blurred
|
39 |
+
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
|
40 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
41 |
+
cv2.imwrite(cur_out_fname, cur_res)
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == '__main__':
|
45 |
+
import argparse
|
46 |
+
|
47 |
+
aparser = argparse.ArgumentParser()
|
48 |
+
aparser.add_argument('config', type=str, help='Path to evaluation config')
|
49 |
+
aparser.add_argument('datadir', type=str,
|
50 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
51 |
+
aparser.add_argument('predictdir', type=str,
|
52 |
+
help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
|
53 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
54 |
+
aparser.add_argument('-s', type=float, default=0.1, help='Gaussian blur sigma')
|
55 |
+
aparser.add_argument('-k', type=int, default=5, help='Kernel size in gaussian blur')
|
56 |
+
|
57 |
+
main(aparser.parse_args())
|
bin/calc_dataset_stats.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import tqdm
|
7 |
+
from scipy.ndimage.morphology import distance_transform_edt
|
8 |
+
|
9 |
+
from saicinpainting.evaluation.data import InpaintingDataset
|
10 |
+
from saicinpainting.evaluation.vis import save_item_for_vis
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
dataset = InpaintingDataset(args.datadir, img_suffix='.png')
|
15 |
+
|
16 |
+
area_bins = np.linspace(0, 1, args.area_bins + 1)
|
17 |
+
|
18 |
+
heights = []
|
19 |
+
widths = []
|
20 |
+
image_areas = []
|
21 |
+
hole_areas = []
|
22 |
+
hole_area_percents = []
|
23 |
+
known_pixel_distances = []
|
24 |
+
|
25 |
+
area_bins_count = np.zeros(args.area_bins)
|
26 |
+
area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)]
|
27 |
+
|
28 |
+
bin2i = [[] for _ in range(args.area_bins)]
|
29 |
+
|
30 |
+
for i, item in enumerate(tqdm.tqdm(dataset)):
|
31 |
+
h, w = item['image'].shape[1:]
|
32 |
+
heights.append(h)
|
33 |
+
widths.append(w)
|
34 |
+
full_area = h * w
|
35 |
+
image_areas.append(full_area)
|
36 |
+
bin_mask = item['mask'] > 0.5
|
37 |
+
hole_area = bin_mask.sum()
|
38 |
+
hole_areas.append(hole_area)
|
39 |
+
hole_percent = hole_area / full_area
|
40 |
+
hole_area_percents.append(hole_percent)
|
41 |
+
bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, len(area_bins_count) - 1)
|
42 |
+
area_bins_count[bin_i] += 1
|
43 |
+
bin2i[bin_i].append(i)
|
44 |
+
|
45 |
+
cur_dist = distance_transform_edt(bin_mask)
|
46 |
+
cur_dist_inside_mask = cur_dist[bin_mask]
|
47 |
+
known_pixel_distances.append(cur_dist_inside_mask.mean())
|
48 |
+
|
49 |
+
os.makedirs(args.outdir, exist_ok=True)
|
50 |
+
with open(os.path.join(args.outdir, 'summary.txt'), 'w') as f:
|
51 |
+
f.write(f'''Location: {args.datadir}
|
52 |
+
|
53 |
+
Number of samples: {len(dataset)}
|
54 |
+
|
55 |
+
Image height: min {min(heights):5d} max {max(heights):5d} mean {np.mean(heights):.2f}
|
56 |
+
Image width: min {min(widths):5d} max {max(widths):5d} mean {np.mean(widths):.2f}
|
57 |
+
Image area: min {min(image_areas):7d} max {max(image_areas):7d} mean {np.mean(image_areas):.2f}
|
58 |
+
Hole area: min {min(hole_areas):7d} max {max(hole_areas):7d} mean {np.mean(hole_areas):.2f}
|
59 |
+
Hole area %: min {min(hole_area_percents) * 100:2.2f} max {max(hole_area_percents) * 100:2.2f} mean {np.mean(hole_area_percents) * 100:2.2f}
|
60 |
+
Dist 2known: min {min(known_pixel_distances):2.2f} max {max(known_pixel_distances):2.2f} mean {np.mean(known_pixel_distances):2.2f} median {np.median(known_pixel_distances):2.2f}
|
61 |
+
|
62 |
+
Stats by hole area %:
|
63 |
+
''')
|
64 |
+
for bin_i in range(args.area_bins):
|
65 |
+
f.write(f'{area_bin_titles[bin_i]}%: '
|
66 |
+
f'samples number {area_bins_count[bin_i]}, '
|
67 |
+
f'{area_bins_count[bin_i] / len(dataset) * 100:.1f}%\n')
|
68 |
+
|
69 |
+
for bin_i in range(args.area_bins):
|
70 |
+
bindir = os.path.join(args.outdir, 'samples', area_bin_titles[bin_i])
|
71 |
+
os.makedirs(bindir, exist_ok=True)
|
72 |
+
bin_idx = bin2i[bin_i]
|
73 |
+
for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False):
|
74 |
+
save_item_for_vis(dataset[sample_i], os.path.join(bindir, f'{sample_i}.png'))
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == '__main__':
|
78 |
+
import argparse
|
79 |
+
|
80 |
+
aparser = argparse.ArgumentParser()
|
81 |
+
aparser.add_argument('datadir', type=str,
|
82 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
83 |
+
aparser.add_argument('outdir', type=str, help='Where to put results')
|
84 |
+
aparser.add_argument('--samples-n', type=int, default=10,
|
85 |
+
help='Number of sample images with masks to copy for visualization for each area bin')
|
86 |
+
aparser.add_argument('--area-bins', type=int, default=10, help='How many area bins to have')
|
87 |
+
|
88 |
+
main(aparser.parse_args())
|
bin/debug/analyze_overlapping_masks.sh
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
BASEDIR="$(dirname $0)"
|
4 |
+
|
5 |
+
# paths are valid for mml7
|
6 |
+
|
7 |
+
# select images
|
8 |
+
#ls /data/inpainting/work/data/train | shuf | head -2000 | xargs -n1 -I{} cp {} /data/inpainting/mask_analysis/src
|
9 |
+
|
10 |
+
# generate masks
|
11 |
+
#"$BASEDIR/../gen_debug_mask_dataset.py" \
|
12 |
+
# "$BASEDIR/../../configs/debug_mask_gen.yaml" \
|
13 |
+
# "/data/inpainting/mask_analysis/src" \
|
14 |
+
# "/data/inpainting/mask_analysis/generated"
|
15 |
+
|
16 |
+
# predict
|
17 |
+
#"$BASEDIR/../predict.py" \
|
18 |
+
# model.path="simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/saved_checkpoint/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15_epoch22-step-574999" \
|
19 |
+
# indir="/data/inpainting/mask_analysis/generated" \
|
20 |
+
# outdir="/data/inpainting/mask_analysis/predicted" \
|
21 |
+
# dataset.img_suffix=.jpg \
|
22 |
+
# +out_ext=.jpg
|
23 |
+
|
24 |
+
# analyze good and bad samples
|
25 |
+
"$BASEDIR/../analyze_errors.py" \
|
26 |
+
--only-report \
|
27 |
+
--n-jobs 8 \
|
28 |
+
"$BASEDIR/../../configs/analyze_mask_errors.yaml" \
|
29 |
+
"/data/inpainting/mask_analysis/small/generated" \
|
30 |
+
"/data/inpainting/mask_analysis/small/predicted" \
|
31 |
+
"/data/inpainting/mask_analysis/small/report"
|
bin/evaluate_predicts.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
|
8 |
+
from saicinpainting.evaluation.evaluator import InpaintingEvaluator, lpips_fid100_f1
|
9 |
+
from saicinpainting.evaluation.losses.base_loss import SegmentationAwareSSIM, \
|
10 |
+
SegmentationClassStats, SSIMScore, LPIPSScore, FIDScore, SegmentationAwareLPIPS, SegmentationAwareFID
|
11 |
+
from saicinpainting.evaluation.utils import load_yaml
|
12 |
+
|
13 |
+
|
14 |
+
def main(args):
|
15 |
+
config = load_yaml(args.config)
|
16 |
+
|
17 |
+
dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
|
18 |
+
|
19 |
+
metrics = {
|
20 |
+
'ssim': SSIMScore(),
|
21 |
+
'lpips': LPIPSScore(),
|
22 |
+
'fid': FIDScore()
|
23 |
+
}
|
24 |
+
enable_segm = config.get('segmentation', dict(enable=False)).get('enable', False)
|
25 |
+
if enable_segm:
|
26 |
+
weights_path = os.path.expandvars(config.segmentation.weights_path)
|
27 |
+
metrics.update(dict(
|
28 |
+
segm_stats=SegmentationClassStats(weights_path=weights_path),
|
29 |
+
segm_ssim=SegmentationAwareSSIM(weights_path=weights_path),
|
30 |
+
segm_lpips=SegmentationAwareLPIPS(weights_path=weights_path),
|
31 |
+
segm_fid=SegmentationAwareFID(weights_path=weights_path)
|
32 |
+
))
|
33 |
+
evaluator = InpaintingEvaluator(dataset, scores=metrics,
|
34 |
+
integral_title='lpips_fid100_f1', integral_func=lpips_fid100_f1,
|
35 |
+
**config.evaluator_kwargs)
|
36 |
+
|
37 |
+
os.makedirs(os.path.dirname(args.outpath), exist_ok=True)
|
38 |
+
|
39 |
+
results = evaluator.evaluate()
|
40 |
+
|
41 |
+
results = pd.DataFrame(results).stack(1).unstack(0)
|
42 |
+
results.dropna(axis=1, how='all', inplace=True)
|
43 |
+
results.to_csv(args.outpath, sep='\t', float_format='%.4f')
|
44 |
+
|
45 |
+
if enable_segm:
|
46 |
+
only_short_results = results[[c for c in results.columns if not c[0].startswith('segm_')]].dropna(axis=1, how='all')
|
47 |
+
only_short_results.to_csv(args.outpath + '_short', sep='\t', float_format='%.4f')
|
48 |
+
|
49 |
+
print(only_short_results)
|
50 |
+
|
51 |
+
segm_metrics_results = results[['segm_ssim', 'segm_lpips', 'segm_fid']].dropna(axis=1, how='all').transpose().unstack(0).reorder_levels([1, 0], axis=1)
|
52 |
+
segm_metrics_results.drop(['mean', 'std'], axis=0, inplace=True)
|
53 |
+
|
54 |
+
segm_stats_results = results['segm_stats'].dropna(axis=1, how='all').transpose()
|
55 |
+
segm_stats_results.index = pd.MultiIndex.from_tuples(n.split('/') for n in segm_stats_results.index)
|
56 |
+
segm_stats_results = segm_stats_results.unstack(0).reorder_levels([1, 0], axis=1)
|
57 |
+
segm_stats_results.sort_index(axis=1, inplace=True)
|
58 |
+
segm_stats_results.dropna(axis=0, how='all', inplace=True)
|
59 |
+
|
60 |
+
segm_results = pd.concat([segm_metrics_results, segm_stats_results], axis=1, sort=True)
|
61 |
+
segm_results.sort_values(('mask_freq', 'total'), ascending=False, inplace=True)
|
62 |
+
|
63 |
+
segm_results.to_csv(args.outpath + '_segm', sep='\t', float_format='%.4f')
|
64 |
+
else:
|
65 |
+
print(results)
|
66 |
+
|
67 |
+
|
68 |
+
if __name__ == '__main__':
|
69 |
+
import argparse
|
70 |
+
|
71 |
+
aparser = argparse.ArgumentParser()
|
72 |
+
aparser.add_argument('config', type=str, help='Path to evaluation config')
|
73 |
+
aparser.add_argument('datadir', type=str,
|
74 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
75 |
+
aparser.add_argument('predictdir', type=str,
|
76 |
+
help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
|
77 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
78 |
+
|
79 |
+
main(aparser.parse_args())
|
bin/evaluator_example.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from skimage import io
|
7 |
+
from skimage.transform import resize
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
|
10 |
+
from saicinpainting.evaluation.evaluator import InpaintingEvaluator
|
11 |
+
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
|
12 |
+
|
13 |
+
|
14 |
+
class SimpleImageDataset(Dataset):
|
15 |
+
def __init__(self, root_dir, image_size=(400, 600)):
|
16 |
+
self.root_dir = root_dir
|
17 |
+
self.files = sorted(os.listdir(root_dir))
|
18 |
+
self.image_size = image_size
|
19 |
+
|
20 |
+
def __getitem__(self, index):
|
21 |
+
img_name = os.path.join(self.root_dir, self.files[index])
|
22 |
+
image = io.imread(img_name)
|
23 |
+
image = resize(image, self.image_size, anti_aliasing=True)
|
24 |
+
image = torch.FloatTensor(image).permute(2, 0, 1)
|
25 |
+
return image
|
26 |
+
|
27 |
+
def __len__(self):
|
28 |
+
return len(self.files)
|
29 |
+
|
30 |
+
|
31 |
+
def create_rectangle_mask(height, width):
|
32 |
+
mask = np.ones((height, width))
|
33 |
+
up_left_corner = width // 4, height // 4
|
34 |
+
down_right_corner = (width - up_left_corner[0] - 1, height - up_left_corner[1] - 1)
|
35 |
+
cv2.rectangle(mask, up_left_corner, down_right_corner, (0, 0, 0), thickness=cv2.FILLED)
|
36 |
+
return mask
|
37 |
+
|
38 |
+
|
39 |
+
class Model():
|
40 |
+
def __call__(self, img_batch, mask_batch):
|
41 |
+
mean = (img_batch * mask_batch[:, None, :, :]).sum(dim=(2, 3)) / mask_batch.sum(dim=(1, 2))[:, None]
|
42 |
+
inpainted = mean[:, :, None, None] * (1 - mask_batch[:, None, :, :]) + img_batch * mask_batch[:, None, :, :]
|
43 |
+
return inpainted
|
44 |
+
|
45 |
+
|
46 |
+
class SimpleImageSquareMaskDataset(Dataset):
|
47 |
+
def __init__(self, dataset):
|
48 |
+
self.dataset = dataset
|
49 |
+
self.mask = torch.FloatTensor(create_rectangle_mask(*self.dataset.image_size))
|
50 |
+
self.model = Model()
|
51 |
+
|
52 |
+
def __getitem__(self, index):
|
53 |
+
img = self.dataset[index]
|
54 |
+
mask = self.mask.clone()
|
55 |
+
inpainted = self.model(img[None, ...], mask[None, ...])
|
56 |
+
return dict(image=img, mask=mask, inpainted=inpainted)
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self.dataset)
|
60 |
+
|
61 |
+
|
62 |
+
dataset = SimpleImageDataset('imgs')
|
63 |
+
mask_dataset = SimpleImageSquareMaskDataset(dataset)
|
64 |
+
model = Model()
|
65 |
+
metrics = {
|
66 |
+
'ssim': SSIMScore(),
|
67 |
+
'lpips': LPIPSScore(),
|
68 |
+
'fid': FIDScore()
|
69 |
+
}
|
70 |
+
|
71 |
+
evaluator = InpaintingEvaluator(
|
72 |
+
mask_dataset, scores=metrics, batch_size=3, area_grouping=True
|
73 |
+
)
|
74 |
+
|
75 |
+
results = evaluator.evaluate(model)
|
76 |
+
print(results)
|
bin/extract_masks.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL.Image as Image
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
def main(args):
|
7 |
+
if not args.indir.endswith('/'):
|
8 |
+
args.indir += '/'
|
9 |
+
os.makedirs(args.outdir, exist_ok=True)
|
10 |
+
|
11 |
+
src_images = [
|
12 |
+
args.indir+fname for fname in os.listdir(args.indir)]
|
13 |
+
|
14 |
+
tgt_masks = [
|
15 |
+
args.outdir+fname[:-4] + f'_mask000.png'
|
16 |
+
for fname in os.listdir(args.indir)]
|
17 |
+
|
18 |
+
for img_name, msk_name in zip(src_images, tgt_masks):
|
19 |
+
#print(img)
|
20 |
+
#print(msk)
|
21 |
+
|
22 |
+
image = Image.open(img_name).convert('RGB')
|
23 |
+
image = np.transpose(np.array(image), (2, 0, 1))
|
24 |
+
|
25 |
+
mask = (image == 255).astype(int)
|
26 |
+
|
27 |
+
print(mask.dtype, mask.shape)
|
28 |
+
|
29 |
+
|
30 |
+
Image.fromarray(
|
31 |
+
np.clip(mask[0,:,:] * 255, 0, 255).astype('uint8'),mode='L'
|
32 |
+
).save(msk_name)
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
'''
|
38 |
+
for infile in src_images:
|
39 |
+
try:
|
40 |
+
file_relpath = infile[len(indir):]
|
41 |
+
img_outpath = os.path.join(outdir, file_relpath)
|
42 |
+
os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
|
43 |
+
|
44 |
+
image = Image.open(infile).convert('RGB')
|
45 |
+
|
46 |
+
mask =
|
47 |
+
|
48 |
+
Image.fromarray(
|
49 |
+
np.clip(
|
50 |
+
cur_mask * 255, 0, 255).astype('uint8'),
|
51 |
+
mode='L'
|
52 |
+
).save(cur_basename + f'_mask{i:03d}.png')
|
53 |
+
'''
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == '__main__':
|
58 |
+
import argparse
|
59 |
+
aparser = argparse.ArgumentParser()
|
60 |
+
aparser.add_argument('--indir', type=str, help='Path to folder with images')
|
61 |
+
aparser.add_argument('--outdir', type=str, help='Path to folder to store aligned images and masks to')
|
62 |
+
|
63 |
+
main(aparser.parse_args())
|
bin/filter_sharded_dataset.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
|
8 |
+
import braceexpand
|
9 |
+
import webdataset as wds
|
10 |
+
|
11 |
+
DEFAULT_CATS_FILE = os.path.join(os.path.dirname(__file__), '..', 'configs', 'places2-categories_157.txt')
|
12 |
+
|
13 |
+
def is_good_key(key, cats):
|
14 |
+
return any(c in key for c in cats)
|
15 |
+
|
16 |
+
|
17 |
+
def main(args):
|
18 |
+
if args.categories == 'nofilter':
|
19 |
+
good_categories = None
|
20 |
+
else:
|
21 |
+
with open(args.categories, 'r') as f:
|
22 |
+
good_categories = set(line.strip().split(' ')[0] for line in f if line.strip())
|
23 |
+
|
24 |
+
all_input_files = list(braceexpand.braceexpand(args.infile))
|
25 |
+
chunk_size = int(math.ceil(len(all_input_files) / args.n_read_streams))
|
26 |
+
|
27 |
+
input_iterators = [iter(wds.Dataset(all_input_files[start : start + chunk_size]).shuffle(args.shuffle_buffer))
|
28 |
+
for start in range(0, len(all_input_files), chunk_size)]
|
29 |
+
output_datasets = [wds.ShardWriter(args.outpattern.format(i)) for i in range(args.n_write_streams)]
|
30 |
+
|
31 |
+
good_readers = list(range(len(input_iterators)))
|
32 |
+
step_i = 0
|
33 |
+
good_samples = 0
|
34 |
+
bad_samples = 0
|
35 |
+
while len(good_readers) > 0:
|
36 |
+
if step_i % args.print_freq == 0:
|
37 |
+
print(f'Iterations done {step_i}; readers alive {good_readers}; good samples {good_samples}; bad samples {bad_samples}')
|
38 |
+
|
39 |
+
step_i += 1
|
40 |
+
|
41 |
+
ri = random.choice(good_readers)
|
42 |
+
try:
|
43 |
+
sample = next(input_iterators[ri])
|
44 |
+
except StopIteration:
|
45 |
+
good_readers = list(set(good_readers) - {ri})
|
46 |
+
continue
|
47 |
+
|
48 |
+
if good_categories is not None and not is_good_key(sample['__key__'], good_categories):
|
49 |
+
bad_samples += 1
|
50 |
+
continue
|
51 |
+
|
52 |
+
wi = random.randint(0, args.n_write_streams - 1)
|
53 |
+
output_datasets[wi].write(sample)
|
54 |
+
good_samples += 1
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == '__main__':
|
58 |
+
import argparse
|
59 |
+
|
60 |
+
aparser = argparse.ArgumentParser()
|
61 |
+
aparser.add_argument('--categories', type=str, default=DEFAULT_CATS_FILE)
|
62 |
+
aparser.add_argument('--shuffle-buffer', type=int, default=10000)
|
63 |
+
aparser.add_argument('--n-read-streams', type=int, default=10)
|
64 |
+
aparser.add_argument('--n-write-streams', type=int, default=10)
|
65 |
+
aparser.add_argument('--print-freq', type=int, default=1000)
|
66 |
+
aparser.add_argument('infile', type=str)
|
67 |
+
aparser.add_argument('outpattern', type=str)
|
68 |
+
|
69 |
+
main(aparser.parse_args())
|
bin/gen_debug_mask_dataset.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
|
6 |
+
import PIL.Image as Image
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import tqdm
|
10 |
+
import shutil
|
11 |
+
|
12 |
+
|
13 |
+
from saicinpainting.evaluation.utils import load_yaml
|
14 |
+
|
15 |
+
|
16 |
+
def generate_masks_for_img(infile, outmask_pattern, mask_size=200, step=0.5):
|
17 |
+
inimg = Image.open(infile)
|
18 |
+
width, height = inimg.size
|
19 |
+
step_abs = int(mask_size * step)
|
20 |
+
|
21 |
+
mask = np.zeros((height, width), dtype='uint8')
|
22 |
+
mask_i = 0
|
23 |
+
|
24 |
+
for start_vertical in range(0, height - step_abs, step_abs):
|
25 |
+
for start_horizontal in range(0, width - step_abs, step_abs):
|
26 |
+
mask[start_vertical:start_vertical + mask_size, start_horizontal:start_horizontal + mask_size] = 255
|
27 |
+
|
28 |
+
cv2.imwrite(outmask_pattern.format(mask_i), mask)
|
29 |
+
|
30 |
+
mask[start_vertical:start_vertical + mask_size, start_horizontal:start_horizontal + mask_size] = 0
|
31 |
+
mask_i += 1
|
32 |
+
|
33 |
+
|
34 |
+
def main(args):
|
35 |
+
if not args.indir.endswith('/'):
|
36 |
+
args.indir += '/'
|
37 |
+
if not args.outdir.endswith('/'):
|
38 |
+
args.outdir += '/'
|
39 |
+
|
40 |
+
config = load_yaml(args.config)
|
41 |
+
|
42 |
+
in_files = list(glob.glob(os.path.join(args.indir, '**', f'*{config.img_ext}'), recursive=True))
|
43 |
+
for infile in tqdm.tqdm(in_files):
|
44 |
+
outimg = args.outdir + infile[len(args.indir):]
|
45 |
+
outmask_pattern = outimg[:-len(config.img_ext)] + '_mask{:04d}.png'
|
46 |
+
|
47 |
+
os.makedirs(os.path.dirname(outimg), exist_ok=True)
|
48 |
+
shutil.copy2(infile, outimg)
|
49 |
+
|
50 |
+
generate_masks_for_img(infile, outmask_pattern, **config.gen_kwargs)
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == '__main__':
|
54 |
+
import argparse
|
55 |
+
|
56 |
+
aparser = argparse.ArgumentParser()
|
57 |
+
aparser.add_argument('config', type=str, help='Path to config for dataset generation')
|
58 |
+
aparser.add_argument('indir', type=str, help='Path to folder with images')
|
59 |
+
aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
|
60 |
+
|
61 |
+
main(aparser.parse_args())
|
bin/gen_mask_dataset.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import traceback
|
7 |
+
|
8 |
+
import PIL.Image as Image
|
9 |
+
import numpy as np
|
10 |
+
from joblib import Parallel, delayed
|
11 |
+
|
12 |
+
from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
|
13 |
+
from saicinpainting.evaluation.utils import load_yaml, SmallMode
|
14 |
+
from saicinpainting.training.data.masks import MixedMaskGenerator
|
15 |
+
|
16 |
+
|
17 |
+
class MakeManyMasksWrapper:
|
18 |
+
def __init__(self, impl, variants_n=2):
|
19 |
+
self.impl = impl
|
20 |
+
self.variants_n = variants_n
|
21 |
+
|
22 |
+
def get_masks(self, img):
|
23 |
+
img = np.transpose(np.array(img), (2, 0, 1))
|
24 |
+
return [self.impl(img)[0] for _ in range(self.variants_n)]
|
25 |
+
|
26 |
+
|
27 |
+
def process_images(src_images, indir, outdir, config):
|
28 |
+
if config.generator_kind == 'segmentation':
|
29 |
+
mask_generator = SegmentationMask(**config.mask_generator_kwargs)
|
30 |
+
elif config.generator_kind == 'random':
|
31 |
+
variants_n = config.mask_generator_kwargs.pop('variants_n', 2)
|
32 |
+
mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**config.mask_generator_kwargs),
|
33 |
+
variants_n=variants_n)
|
34 |
+
else:
|
35 |
+
raise ValueError(f'Unexpected generator kind: {config.generator_kind}')
|
36 |
+
|
37 |
+
max_tamper_area = config.get('max_tamper_area', 1)
|
38 |
+
|
39 |
+
for infile in src_images:
|
40 |
+
try:
|
41 |
+
file_relpath = infile[len(indir):]
|
42 |
+
img_outpath = os.path.join(outdir, file_relpath)
|
43 |
+
os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
|
44 |
+
|
45 |
+
image = Image.open(infile).convert('RGB')
|
46 |
+
|
47 |
+
# scale input image to output resolution and filter smaller images
|
48 |
+
if min(image.size) < config.cropping.out_min_size:
|
49 |
+
handle_small_mode = SmallMode(config.cropping.handle_small_mode)
|
50 |
+
if handle_small_mode == SmallMode.DROP:
|
51 |
+
continue
|
52 |
+
elif handle_small_mode == SmallMode.UPSCALE:
|
53 |
+
factor = config.cropping.out_min_size / min(image.size)
|
54 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
55 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
56 |
+
else:
|
57 |
+
factor = config.cropping.out_min_size / min(image.size)
|
58 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
59 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
60 |
+
|
61 |
+
# generate and select masks
|
62 |
+
src_masks = mask_generator.get_masks(image)
|
63 |
+
|
64 |
+
filtered_image_mask_pairs = []
|
65 |
+
for cur_mask in src_masks:
|
66 |
+
if config.cropping.out_square_crop:
|
67 |
+
(crop_left,
|
68 |
+
crop_top,
|
69 |
+
crop_right,
|
70 |
+
crop_bottom) = propose_random_square_crop(cur_mask,
|
71 |
+
min_overlap=config.cropping.crop_min_overlap)
|
72 |
+
cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
|
73 |
+
cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
|
74 |
+
else:
|
75 |
+
cur_image = image
|
76 |
+
|
77 |
+
if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
|
78 |
+
continue
|
79 |
+
|
80 |
+
filtered_image_mask_pairs.append((cur_image, cur_mask))
|
81 |
+
|
82 |
+
mask_indices = np.random.choice(len(filtered_image_mask_pairs),
|
83 |
+
size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
|
84 |
+
replace=False)
|
85 |
+
|
86 |
+
# crop masks; save masks together with input image
|
87 |
+
mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
|
88 |
+
for i, idx in enumerate(mask_indices):
|
89 |
+
cur_image, cur_mask = filtered_image_mask_pairs[idx]
|
90 |
+
cur_basename = mask_basename + f'_crop{i:03d}'
|
91 |
+
Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
|
92 |
+
mode='L').save(cur_basename + f'_mask{i:03d}.png')
|
93 |
+
cur_image.save(cur_basename + '.png')
|
94 |
+
except KeyboardInterrupt:
|
95 |
+
return
|
96 |
+
except Exception as ex:
|
97 |
+
print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
|
98 |
+
|
99 |
+
|
100 |
+
def main(args):
|
101 |
+
if not args.indir.endswith('/'):
|
102 |
+
args.indir += '/'
|
103 |
+
|
104 |
+
os.makedirs(args.outdir, exist_ok=True)
|
105 |
+
|
106 |
+
config = load_yaml(args.config)
|
107 |
+
|
108 |
+
in_files = list(glob.glob(os.path.join(args.indir, '**', f'*.{args.ext}'), recursive=True))
|
109 |
+
if args.n_jobs == 0:
|
110 |
+
process_images(in_files, args.indir, args.outdir, config)
|
111 |
+
else:
|
112 |
+
in_files_n = len(in_files)
|
113 |
+
chunk_size = in_files_n // args.n_jobs + (1 if in_files_n % args.n_jobs > 0 else 0)
|
114 |
+
Parallel(n_jobs=args.n_jobs)(
|
115 |
+
delayed(process_images)(in_files[start:start+chunk_size], args.indir, args.outdir, config)
|
116 |
+
for start in range(0, len(in_files), chunk_size)
|
117 |
+
)
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
import argparse
|
122 |
+
|
123 |
+
aparser = argparse.ArgumentParser()
|
124 |
+
aparser.add_argument('config', type=str, help='Path to config for dataset generation')
|
125 |
+
aparser.add_argument('indir', type=str, help='Path to folder with images')
|
126 |
+
aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
|
127 |
+
aparser.add_argument('--n-jobs', type=int, default=0, help='How many processes to use')
|
128 |
+
aparser.add_argument('--ext', type=str, default='jpg', help='Input image extension')
|
129 |
+
|
130 |
+
main(aparser.parse_args())
|
bin/gen_mask_dataset_hydra.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import traceback
|
7 |
+
import hydra
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
import PIL.Image as Image
|
11 |
+
import numpy as np
|
12 |
+
from joblib import Parallel, delayed
|
13 |
+
|
14 |
+
from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
|
15 |
+
from saicinpainting.evaluation.utils import load_yaml, SmallMode
|
16 |
+
from saicinpainting.training.data.masks import MixedMaskGenerator
|
17 |
+
|
18 |
+
|
19 |
+
class MakeManyMasksWrapper:
|
20 |
+
def __init__(self, impl, variants_n=2):
|
21 |
+
self.impl = impl
|
22 |
+
self.variants_n = variants_n
|
23 |
+
|
24 |
+
def get_masks(self, img):
|
25 |
+
img = np.transpose(np.array(img), (2, 0, 1))
|
26 |
+
return [self.impl(img)[0] for _ in range(self.variants_n)]
|
27 |
+
|
28 |
+
|
29 |
+
def process_images(src_images, indir, outdir, config):
|
30 |
+
if config.generator_kind == 'segmentation':
|
31 |
+
mask_generator = SegmentationMask(**config.mask_generator_kwargs)
|
32 |
+
elif config.generator_kind == 'random':
|
33 |
+
mask_generator_kwargs = OmegaConf.to_container(config.mask_generator_kwargs, resolve=True)
|
34 |
+
variants_n = mask_generator_kwargs.pop('variants_n', 2)
|
35 |
+
mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**mask_generator_kwargs),
|
36 |
+
variants_n=variants_n)
|
37 |
+
else:
|
38 |
+
raise ValueError(f'Unexpected generator kind: {config.generator_kind}')
|
39 |
+
|
40 |
+
max_tamper_area = config.get('max_tamper_area', 1)
|
41 |
+
|
42 |
+
for infile in src_images:
|
43 |
+
try:
|
44 |
+
file_relpath = infile[len(indir):]
|
45 |
+
img_outpath = os.path.join(outdir, file_relpath)
|
46 |
+
os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
|
47 |
+
|
48 |
+
image = Image.open(infile).convert('RGB')
|
49 |
+
|
50 |
+
# scale input image to output resolution and filter smaller images
|
51 |
+
if min(image.size) < config.cropping.out_min_size:
|
52 |
+
handle_small_mode = SmallMode(config.cropping.handle_small_mode)
|
53 |
+
if handle_small_mode == SmallMode.DROP:
|
54 |
+
continue
|
55 |
+
elif handle_small_mode == SmallMode.UPSCALE:
|
56 |
+
factor = config.cropping.out_min_size / min(image.size)
|
57 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
58 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
59 |
+
else:
|
60 |
+
factor = config.cropping.out_min_size / min(image.size)
|
61 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
62 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
63 |
+
|
64 |
+
# generate and select masks
|
65 |
+
src_masks = mask_generator.get_masks(image)
|
66 |
+
|
67 |
+
filtered_image_mask_pairs = []
|
68 |
+
for cur_mask in src_masks:
|
69 |
+
if config.cropping.out_square_crop:
|
70 |
+
(crop_left,
|
71 |
+
crop_top,
|
72 |
+
crop_right,
|
73 |
+
crop_bottom) = propose_random_square_crop(cur_mask,
|
74 |
+
min_overlap=config.cropping.crop_min_overlap)
|
75 |
+
cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
|
76 |
+
cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
|
77 |
+
else:
|
78 |
+
cur_image = image
|
79 |
+
|
80 |
+
if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
|
81 |
+
continue
|
82 |
+
|
83 |
+
filtered_image_mask_pairs.append((cur_image, cur_mask))
|
84 |
+
|
85 |
+
mask_indices = np.random.choice(len(filtered_image_mask_pairs),
|
86 |
+
size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
|
87 |
+
replace=False)
|
88 |
+
|
89 |
+
# crop masks; save masks together with input image
|
90 |
+
mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
|
91 |
+
for i, idx in enumerate(mask_indices):
|
92 |
+
cur_image, cur_mask = filtered_image_mask_pairs[idx]
|
93 |
+
cur_basename = mask_basename + f'_crop{i:03d}'
|
94 |
+
Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
|
95 |
+
mode='L').save(cur_basename + f'_mask{i:03d}.png')
|
96 |
+
cur_image.save(cur_basename + '.png')
|
97 |
+
except KeyboardInterrupt:
|
98 |
+
return
|
99 |
+
except Exception as ex:
|
100 |
+
print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
|
101 |
+
|
102 |
+
|
103 |
+
@hydra.main(config_path='../configs/data_gen/whydra', config_name='random_medium_256.yaml')
|
104 |
+
def main(config: OmegaConf):
|
105 |
+
if not config.indir.endswith('/'):
|
106 |
+
config.indir += '/'
|
107 |
+
|
108 |
+
os.makedirs(config.outdir, exist_ok=True)
|
109 |
+
|
110 |
+
in_files = list(glob.glob(os.path.join(config.indir, '**', f'*.{config.location.extension}'),
|
111 |
+
recursive=True))
|
112 |
+
if config.n_jobs == 0:
|
113 |
+
process_images(in_files, config.indir, config.outdir, config)
|
114 |
+
else:
|
115 |
+
in_files_n = len(in_files)
|
116 |
+
chunk_size = in_files_n // config.n_jobs + (1 if in_files_n % config.n_jobs > 0 else 0)
|
117 |
+
Parallel(n_jobs=config.n_jobs)(
|
118 |
+
delayed(process_images)(in_files[start:start+chunk_size], config.indir, config.outdir, config)
|
119 |
+
for start in range(0, len(in_files), chunk_size)
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
if __name__ == '__main__':
|
124 |
+
main()
|
bin/gen_outpainting_dataset.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import glob
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import sys
|
7 |
+
import traceback
|
8 |
+
|
9 |
+
from saicinpainting.evaluation.data import load_image
|
10 |
+
from saicinpainting.evaluation.utils import move_to_device
|
11 |
+
|
12 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
13 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
14 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
15 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
16 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
17 |
+
|
18 |
+
import cv2
|
19 |
+
import hydra
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import tqdm
|
23 |
+
import yaml
|
24 |
+
from omegaconf import OmegaConf
|
25 |
+
from torch.utils.data._utils.collate import default_collate
|
26 |
+
|
27 |
+
from saicinpainting.training.data.datasets import make_default_val_dataset
|
28 |
+
from saicinpainting.training.trainers import load_checkpoint
|
29 |
+
from saicinpainting.utils import register_debug_signal_handlers
|
30 |
+
|
31 |
+
LOGGER = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
def main(args):
|
35 |
+
try:
|
36 |
+
if not args.indir.endswith('/'):
|
37 |
+
args.indir += '/'
|
38 |
+
|
39 |
+
for in_img in glob.glob(os.path.join(args.indir, '**', '*' + args.img_suffix), recursive=True):
|
40 |
+
if 'mask' in os.path.basename(in_img):
|
41 |
+
continue
|
42 |
+
|
43 |
+
out_img_path = os.path.join(args.outdir, os.path.splitext(in_img[len(args.indir):])[0] + '.png')
|
44 |
+
out_mask_path = f'{os.path.splitext(out_img_path)[0]}_mask.png'
|
45 |
+
|
46 |
+
os.makedirs(os.path.dirname(out_img_path), exist_ok=True)
|
47 |
+
|
48 |
+
img = load_image(in_img)
|
49 |
+
height, width = img.shape[1:]
|
50 |
+
pad_h, pad_w = int(height * args.coef / 2), int(width * args.coef / 2)
|
51 |
+
|
52 |
+
mask = np.zeros((height, width), dtype='uint8')
|
53 |
+
|
54 |
+
if args.expand:
|
55 |
+
img = np.pad(img, ((0, 0), (pad_h, pad_h), (pad_w, pad_w)))
|
56 |
+
mask = np.pad(mask, ((pad_h, pad_h), (pad_w, pad_w)), mode='constant', constant_values=255)
|
57 |
+
else:
|
58 |
+
mask[:pad_h] = 255
|
59 |
+
mask[-pad_h:] = 255
|
60 |
+
mask[:, :pad_w] = 255
|
61 |
+
mask[:, -pad_w:] = 255
|
62 |
+
|
63 |
+
# img = np.pad(img, ((0, 0), (pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode='symmetric')
|
64 |
+
# mask = np.pad(mask, ((pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode = 'symmetric')
|
65 |
+
|
66 |
+
img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype('uint8')
|
67 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
68 |
+
cv2.imwrite(out_img_path, img)
|
69 |
+
|
70 |
+
cv2.imwrite(out_mask_path, mask)
|
71 |
+
except KeyboardInterrupt:
|
72 |
+
LOGGER.warning('Interrupted by user')
|
73 |
+
except Exception as ex:
|
74 |
+
LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
|
75 |
+
sys.exit(1)
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == '__main__':
|
79 |
+
import argparse
|
80 |
+
|
81 |
+
aparser = argparse.ArgumentParser()
|
82 |
+
aparser.add_argument('indir', type=str, help='Root directory with images')
|
83 |
+
aparser.add_argument('outdir', type=str, help='Where to store results')
|
84 |
+
aparser.add_argument('--img-suffix', type=str, default='.png', help='Input image extension')
|
85 |
+
aparser.add_argument('--expand', action='store_true', help='Generate mask by padding (true) or by cropping (false)')
|
86 |
+
aparser.add_argument('--coef', type=float, default=0.2, help='How much to crop/expand in order to get masks')
|
87 |
+
|
88 |
+
main(aparser.parse_args())
|
bin/make_checkpoint.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def get_checkpoint_files(s):
|
10 |
+
s = s.strip()
|
11 |
+
if ',' in s:
|
12 |
+
return [get_checkpoint_files(chunk) for chunk in s.split(',')]
|
13 |
+
return 'last.ckpt' if s == 'last' else f'{s}.ckpt'
|
14 |
+
|
15 |
+
|
16 |
+
def main(args):
|
17 |
+
checkpoint_fnames = get_checkpoint_files(args.epochs)
|
18 |
+
if isinstance(checkpoint_fnames, str):
|
19 |
+
checkpoint_fnames = [checkpoint_fnames]
|
20 |
+
assert len(checkpoint_fnames) >= 1
|
21 |
+
|
22 |
+
checkpoint_path = os.path.join(args.indir, 'models', checkpoint_fnames[0])
|
23 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
24 |
+
del checkpoint['optimizer_states']
|
25 |
+
|
26 |
+
if len(checkpoint_fnames) > 1:
|
27 |
+
for fname in checkpoint_fnames[1:]:
|
28 |
+
print('sum', fname)
|
29 |
+
sum_tensors_cnt = 0
|
30 |
+
other_cp = torch.load(os.path.join(args.indir, 'models', fname), map_location='cpu')
|
31 |
+
for k in checkpoint['state_dict'].keys():
|
32 |
+
if checkpoint['state_dict'][k].dtype is torch.float:
|
33 |
+
checkpoint['state_dict'][k].data.add_(other_cp['state_dict'][k].data)
|
34 |
+
sum_tensors_cnt += 1
|
35 |
+
print('summed', sum_tensors_cnt, 'tensors')
|
36 |
+
|
37 |
+
for k in checkpoint['state_dict'].keys():
|
38 |
+
if checkpoint['state_dict'][k].dtype is torch.float:
|
39 |
+
checkpoint['state_dict'][k].data.mul_(1 / float(len(checkpoint_fnames)))
|
40 |
+
|
41 |
+
state_dict = checkpoint['state_dict']
|
42 |
+
|
43 |
+
if not args.leave_discriminators:
|
44 |
+
for k in list(state_dict.keys()):
|
45 |
+
if k.startswith('discriminator.'):
|
46 |
+
del state_dict[k]
|
47 |
+
|
48 |
+
if not args.leave_losses:
|
49 |
+
for k in list(state_dict.keys()):
|
50 |
+
if k.startswith('loss_'):
|
51 |
+
del state_dict[k]
|
52 |
+
|
53 |
+
out_checkpoint_path = os.path.join(args.outdir, 'models', 'best.ckpt')
|
54 |
+
os.makedirs(os.path.dirname(out_checkpoint_path), exist_ok=True)
|
55 |
+
|
56 |
+
torch.save(checkpoint, out_checkpoint_path)
|
57 |
+
|
58 |
+
shutil.copy2(os.path.join(args.indir, 'config.yaml'),
|
59 |
+
os.path.join(args.outdir, 'config.yaml'))
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == '__main__':
|
63 |
+
import argparse
|
64 |
+
|
65 |
+
aparser = argparse.ArgumentParser()
|
66 |
+
aparser.add_argument('indir',
|
67 |
+
help='Path to directory with output of training '
|
68 |
+
'(i.e. directory, which has samples, modules, config.yaml and train.log')
|
69 |
+
aparser.add_argument('outdir',
|
70 |
+
help='Where to put minimal checkpoint, which can be consumed by "bin/predict.py"')
|
71 |
+
aparser.add_argument('--epochs', type=str, default='last',
|
72 |
+
help='Which checkpoint to take. '
|
73 |
+
'Can be "last" or integer - number of epoch')
|
74 |
+
aparser.add_argument('--leave-discriminators', action='store_true',
|
75 |
+
help='If enabled, the state of discriminators will not be removed from the checkpoint')
|
76 |
+
aparser.add_argument('--leave-losses', action='store_true',
|
77 |
+
help='If enabled, weights of nn-based losses (e.g. perceptual) will not be removed')
|
78 |
+
|
79 |
+
main(aparser.parse_args())
|
bin/mask_example.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
from skimage import io
|
3 |
+
from skimage.transform import resize
|
4 |
+
|
5 |
+
from saicinpainting.evaluation.masks.mask import SegmentationMask
|
6 |
+
|
7 |
+
im = io.imread('imgs/ex4.jpg')
|
8 |
+
im = resize(im, (512, 1024), anti_aliasing=True)
|
9 |
+
mask_seg = SegmentationMask(num_variants_per_mask=10)
|
10 |
+
mask_examples = mask_seg.get_masks(im)
|
11 |
+
for i, example in enumerate(mask_examples):
|
12 |
+
plt.imshow(example)
|
13 |
+
plt.show()
|
14 |
+
plt.imsave(f'tmp/img_masks/{i}.png', example)
|
bin/paper_runfiles/blur_tests.sh
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
##!/usr/bin/env bash
|
2 |
+
#
|
3 |
+
## !!! file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst
|
4 |
+
#
|
5 |
+
## paths to data are valid for mml7
|
6 |
+
#PLACES_ROOT="/data/inpainting/Places365"
|
7 |
+
#OUT_DIR="/data/inpainting/paper_data/Places365_val_test"
|
8 |
+
#
|
9 |
+
#source "$(dirname $0)/env.sh"
|
10 |
+
#
|
11 |
+
#for datadir in test_large_30k # val_large
|
12 |
+
#do
|
13 |
+
# for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
14 |
+
# do
|
15 |
+
# "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
16 |
+
# "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8
|
17 |
+
#
|
18 |
+
# "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
19 |
+
# done
|
20 |
+
#
|
21 |
+
# for conf in segm_256 segm_512
|
22 |
+
# do
|
23 |
+
# "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
24 |
+
# "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2
|
25 |
+
#
|
26 |
+
# "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
27 |
+
# done
|
28 |
+
#done
|
29 |
+
#
|
30 |
+
#IN_DIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k/random_medium_512"
|
31 |
+
#PRED_DIR="/data/inpainting/predictions/final/images/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37/random_medium_512"
|
32 |
+
#BLUR_OUT_DIR="/data/inpainting/predictions/final/blur/images"
|
33 |
+
#
|
34 |
+
#for b in 0.1
|
35 |
+
#
|
36 |
+
#"$BINDIR/blur_predicts.py" "$BASEDIR/../../configs/eval2.yaml" "$CUR_IN_DIR" "$CUR_OUT_DIR" "$CUR_EVAL_DIR"
|
37 |
+
#
|
bin/paper_runfiles/env.sh
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DIRNAME="$(dirname $0)"
|
2 |
+
DIRNAME="$(realpath ""$DIRNAME"")"
|
3 |
+
|
4 |
+
BINDIR="$DIRNAME/.."
|
5 |
+
SRCDIR="$BINDIR/.."
|
6 |
+
CONFIGDIR="$SRCDIR/configs"
|
7 |
+
|
8 |
+
export PYTHONPATH="$SRCDIR:$PYTHONPATH"
|
bin/paper_runfiles/find_best_checkpoint.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
|
4 |
+
import os
|
5 |
+
from argparse import ArgumentParser
|
6 |
+
|
7 |
+
|
8 |
+
def ssim_fid100_f1(metrics, fid_scale=100):
|
9 |
+
ssim = metrics.loc['total', 'ssim']['mean']
|
10 |
+
fid = metrics.loc['total', 'fid']['mean']
|
11 |
+
fid_rel = max(0, fid_scale - fid) / fid_scale
|
12 |
+
f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)
|
13 |
+
return f1
|
14 |
+
|
15 |
+
|
16 |
+
def find_best_checkpoint(model_list, models_dir):
|
17 |
+
with open(model_list) as f:
|
18 |
+
models = [m.strip() for m in f.readlines()]
|
19 |
+
with open(f'{model_list}_best', 'w') as f:
|
20 |
+
for model in models:
|
21 |
+
print(model)
|
22 |
+
best_f1 = 0
|
23 |
+
best_epoch = 0
|
24 |
+
best_step = 0
|
25 |
+
with open(os.path.join(models_dir, model, 'train.log')) as fm:
|
26 |
+
lines = fm.readlines()
|
27 |
+
for line_index in range(len(lines)):
|
28 |
+
line = lines[line_index]
|
29 |
+
if 'Validation metrics after epoch' in line:
|
30 |
+
sharp_index = line.index('#')
|
31 |
+
cur_ep = line[sharp_index + 1:]
|
32 |
+
comma_index = cur_ep.index(',')
|
33 |
+
cur_ep = int(cur_ep[:comma_index])
|
34 |
+
total_index = line.index('total ')
|
35 |
+
step = int(line[total_index:].split()[1].strip())
|
36 |
+
total_line = lines[line_index + 5]
|
37 |
+
if not total_line.startswith('total'):
|
38 |
+
continue
|
39 |
+
words = total_line.strip().split()
|
40 |
+
f1 = float(words[-1])
|
41 |
+
print(f'\tEpoch: {cur_ep}, f1={f1}')
|
42 |
+
if f1 > best_f1:
|
43 |
+
best_f1 = f1
|
44 |
+
best_epoch = cur_ep
|
45 |
+
best_step = step
|
46 |
+
f.write(f'{model}\t{best_epoch}\t{best_step}\t{best_f1}\n')
|
47 |
+
|
48 |
+
|
49 |
+
if __name__ == '__main__':
|
50 |
+
parser = ArgumentParser()
|
51 |
+
parser.add_argument('model_list')
|
52 |
+
parser.add_argument('models_dir')
|
53 |
+
args = parser.parse_args()
|
54 |
+
find_best_checkpoint(args.model_list, args.models_dir)
|
bin/paper_runfiles/generate_test_celeba-hq.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/CelebA-HQ_val_test"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in "val" "test"
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-celeba-hq \
|
13 |
+
location.out_dir=$OUT_DIR cropping.out_square_crop=False
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
bin/paper_runfiles/generate_test_ffhq.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/FFHQ_val"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in test
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-ffhq \
|
13 |
+
location.out_dir=$OUT_DIR cropping.out_square_crop=False
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
bin/paper_runfiles/generate_test_paris.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in paris_eval_gt
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 segm_256
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-paris \
|
13 |
+
location.out_dir=OUT_DIR cropping.out_square_crop=False cropping.out_min_size=227
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
bin/paper_runfiles/generate_test_paris_256.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val_256"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in paris_eval_gt
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 segm_256
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-paris \
|
13 |
+
location.out_dir=$OUT_DIR cropping.out_square_crop=False cropping.out_min_size=256
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
bin/paper_runfiles/generate_val_test.sh
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# !!! file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst
|
4 |
+
|
5 |
+
# paths to data are valid for mml7
|
6 |
+
PLACES_ROOT="/data/inpainting/Places365"
|
7 |
+
OUT_DIR="/data/inpainting/paper_data/Places365_val_test"
|
8 |
+
|
9 |
+
source "$(dirname $0)/env.sh"
|
10 |
+
|
11 |
+
for datadir in test_large_30k # val_large
|
12 |
+
do
|
13 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
14 |
+
do
|
15 |
+
"$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
16 |
+
"$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8
|
17 |
+
|
18 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
19 |
+
done
|
20 |
+
|
21 |
+
for conf in segm_256 segm_512
|
22 |
+
do
|
23 |
+
"$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
24 |
+
"$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2
|
25 |
+
|
26 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
27 |
+
done
|
28 |
+
done
|
bin/paper_runfiles/predict_inner_features.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml7
|
4 |
+
|
5 |
+
source "$(dirname $0)/env.sh"
|
6 |
+
|
7 |
+
"$BINDIR/predict_inner_features.py" \
|
8 |
+
-cn default_inner_features_ffc \
|
9 |
+
model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-34-05_train_ablv2_work_ffc075_resume_epoch39" \
|
10 |
+
indir="/data/inpainting/paper_data/inner_features_vis/input/" \
|
11 |
+
outdir="/data/inpainting/paper_data/inner_features_vis/output/ffc" \
|
12 |
+
dataset.img_suffix=.png
|
13 |
+
|
14 |
+
|
15 |
+
"$BINDIR/predict_inner_features.py" \
|
16 |
+
-cn default_inner_features_work \
|
17 |
+
model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37" \
|
18 |
+
indir="/data/inpainting/paper_data/inner_features_vis/input/" \
|
19 |
+
outdir="/data/inpainting/paper_data/inner_features_vis/output/work" \
|
20 |
+
dataset.img_suffix=.png
|
bin/paper_runfiles/update_test_data_stats.sh
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml7
|
4 |
+
|
5 |
+
source "$(dirname $0)/env.sh"
|
6 |
+
|
7 |
+
#INDIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k"
|
8 |
+
#
|
9 |
+
#for dataset in random_medium_256 random_medium_512 random_thick_256 random_thick_512 random_thin_256 random_thin_512
|
10 |
+
#do
|
11 |
+
# "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
|
12 |
+
#done
|
13 |
+
#
|
14 |
+
#"$BINDIR/calc_dataset_stats.py" "/data/inpainting/evalset2" "/data/inpainting/evalset2_stats2"
|
15 |
+
|
16 |
+
|
17 |
+
INDIR="/data/inpainting/paper_data/CelebA-HQ_val_test/test"
|
18 |
+
|
19 |
+
for dataset in random_medium_256 random_thick_256 random_thin_256
|
20 |
+
do
|
21 |
+
"$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
|
22 |
+
done
|
23 |
+
|
24 |
+
|
25 |
+
INDIR="/data/inpainting/paper_data/Paris_StreetView_Dataset_val_256/paris_eval_gt"
|
26 |
+
|
27 |
+
for dataset in random_medium_256 random_thick_256 random_thin_256
|
28 |
+
do
|
29 |
+
"$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
|
30 |
+
done
|
bin/predict.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Example command:
|
4 |
+
# ./bin/predict.py \
|
5 |
+
# model.path=<path to checkpoint, prepared by make_checkpoint.py> \
|
6 |
+
# indir=<path to input data> \
|
7 |
+
# outdir=<where to store predicts>
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import traceback
|
13 |
+
|
14 |
+
from saicinpainting.evaluation.utils import move_to_device
|
15 |
+
|
16 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
17 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
18 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
19 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
20 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
21 |
+
|
22 |
+
import cv2
|
23 |
+
import hydra
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
import tqdm
|
27 |
+
import yaml
|
28 |
+
from omegaconf import OmegaConf
|
29 |
+
from torch.utils.data._utils.collate import default_collate
|
30 |
+
|
31 |
+
from saicinpainting.training.data.datasets import make_default_val_dataset
|
32 |
+
from saicinpainting.training.trainers import load_checkpoint
|
33 |
+
from saicinpainting.utils import register_debug_signal_handlers
|
34 |
+
|
35 |
+
LOGGER = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
@hydra.main(config_path='../configs/prediction', config_name='default.yaml')
|
39 |
+
def main(predict_config: OmegaConf):
|
40 |
+
try:
|
41 |
+
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
|
42 |
+
|
43 |
+
device = torch.device(predict_config.device)
|
44 |
+
|
45 |
+
train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
|
46 |
+
with open(train_config_path, 'r') as f:
|
47 |
+
train_config = OmegaConf.create(yaml.safe_load(f))
|
48 |
+
|
49 |
+
train_config.training_model.predict_only = True
|
50 |
+
|
51 |
+
out_ext = predict_config.get('out_ext', '.png')
|
52 |
+
|
53 |
+
checkpoint_path = os.path.join(predict_config.model.path,
|
54 |
+
'models',
|
55 |
+
predict_config.model.checkpoint)
|
56 |
+
model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
|
57 |
+
model.freeze()
|
58 |
+
model.to(device)
|
59 |
+
|
60 |
+
if not predict_config.indir.endswith('/'):
|
61 |
+
predict_config.indir += '/'
|
62 |
+
|
63 |
+
dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
|
64 |
+
with torch.no_grad():
|
65 |
+
for img_i in tqdm.trange(len(dataset)):
|
66 |
+
mask_fname = dataset.mask_filenames[img_i]
|
67 |
+
cur_out_fname = os.path.join(
|
68 |
+
predict_config.outdir,
|
69 |
+
os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext
|
70 |
+
)
|
71 |
+
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
|
72 |
+
|
73 |
+
batch = move_to_device(default_collate([dataset[img_i]]), device)
|
74 |
+
batch['mask'] = (batch['mask'] > 0) * 1
|
75 |
+
batch = model(batch)
|
76 |
+
cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
|
77 |
+
|
78 |
+
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
|
79 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
80 |
+
cv2.imwrite(cur_out_fname, cur_res)
|
81 |
+
except KeyboardInterrupt:
|
82 |
+
LOGGER.warning('Interrupted by user')
|
83 |
+
except Exception as ex:
|
84 |
+
LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
|
85 |
+
sys.exit(1)
|
86 |
+
|
87 |
+
|
88 |
+
if __name__ == '__main__':
|
89 |
+
main()
|
bin/predict_inner_features.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Example command:
|
4 |
+
# ./bin/predict.py \
|
5 |
+
# model.path=<path to checkpoint, prepared by make_checkpoint.py> \
|
6 |
+
# indir=<path to input data> \
|
7 |
+
# outdir=<where to store predicts>
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import traceback
|
13 |
+
|
14 |
+
from saicinpainting.evaluation.utils import move_to_device
|
15 |
+
|
16 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
17 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
18 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
19 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
20 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
21 |
+
|
22 |
+
import cv2
|
23 |
+
import hydra
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
import tqdm
|
27 |
+
import yaml
|
28 |
+
from omegaconf import OmegaConf
|
29 |
+
from torch.utils.data._utils.collate import default_collate
|
30 |
+
|
31 |
+
from saicinpainting.training.data.datasets import make_default_val_dataset
|
32 |
+
from saicinpainting.training.trainers import load_checkpoint, DefaultInpaintingTrainingModule
|
33 |
+
from saicinpainting.utils import register_debug_signal_handlers, get_shape
|
34 |
+
|
35 |
+
LOGGER = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
@hydra.main(config_path='../configs/prediction', config_name='default_inner_features.yaml')
|
39 |
+
def main(predict_config: OmegaConf):
|
40 |
+
try:
|
41 |
+
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
|
42 |
+
|
43 |
+
device = torch.device(predict_config.device)
|
44 |
+
|
45 |
+
train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
|
46 |
+
with open(train_config_path, 'r') as f:
|
47 |
+
train_config = OmegaConf.create(yaml.safe_load(f))
|
48 |
+
|
49 |
+
checkpoint_path = os.path.join(predict_config.model.path, 'models', predict_config.model.checkpoint)
|
50 |
+
model = load_checkpoint(train_config, checkpoint_path, strict=False)
|
51 |
+
model.freeze()
|
52 |
+
model.to(device)
|
53 |
+
|
54 |
+
assert isinstance(model, DefaultInpaintingTrainingModule), 'Only DefaultInpaintingTrainingModule is supported'
|
55 |
+
assert isinstance(getattr(model.generator, 'model', None), torch.nn.Sequential)
|
56 |
+
|
57 |
+
if not predict_config.indir.endswith('/'):
|
58 |
+
predict_config.indir += '/'
|
59 |
+
|
60 |
+
dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
|
61 |
+
|
62 |
+
max_level = max(predict_config.levels)
|
63 |
+
|
64 |
+
with torch.no_grad():
|
65 |
+
for img_i in tqdm.trange(len(dataset)):
|
66 |
+
mask_fname = dataset.mask_filenames[img_i]
|
67 |
+
cur_out_fname = os.path.join(predict_config.outdir, os.path.splitext(mask_fname[len(predict_config.indir):])[0])
|
68 |
+
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
|
69 |
+
|
70 |
+
batch = move_to_device(default_collate([dataset[img_i]]), device)
|
71 |
+
|
72 |
+
img = batch['image']
|
73 |
+
mask = batch['mask']
|
74 |
+
mask[:] = 0
|
75 |
+
mask_h, mask_w = mask.shape[-2:]
|
76 |
+
mask[:, :,
|
77 |
+
mask_h // 2 - predict_config.hole_radius : mask_h // 2 + predict_config.hole_radius,
|
78 |
+
mask_w // 2 - predict_config.hole_radius : mask_w // 2 + predict_config.hole_radius] = 1
|
79 |
+
|
80 |
+
masked_img = torch.cat([img * (1 - mask), mask], dim=1)
|
81 |
+
|
82 |
+
feats = masked_img
|
83 |
+
for level_i, level in enumerate(model.generator.model):
|
84 |
+
feats = level(feats)
|
85 |
+
if level_i in predict_config.levels:
|
86 |
+
cur_feats = torch.cat([f for f in feats if torch.is_tensor(f)], dim=1) \
|
87 |
+
if isinstance(feats, tuple) else feats
|
88 |
+
|
89 |
+
if predict_config.slice_channels:
|
90 |
+
cur_feats = cur_feats[:, slice(*predict_config.slice_channels)]
|
91 |
+
|
92 |
+
cur_feat = cur_feats.pow(2).mean(1).pow(0.5).clone()
|
93 |
+
cur_feat -= cur_feat.min()
|
94 |
+
cur_feat /= cur_feat.std()
|
95 |
+
cur_feat = cur_feat.clamp(0, 1) / 1
|
96 |
+
cur_feat = cur_feat.cpu().numpy()[0]
|
97 |
+
cur_feat *= 255
|
98 |
+
cur_feat = np.clip(cur_feat, 0, 255).astype('uint8')
|
99 |
+
cv2.imwrite(cur_out_fname + f'_lev{level_i:02d}_norm.png', cur_feat)
|
100 |
+
|
101 |
+
# for channel_i in predict_config.channels:
|
102 |
+
#
|
103 |
+
# cur_feat = cur_feats[0, channel_i].clone().detach().cpu().numpy()
|
104 |
+
# cur_feat -= cur_feat.min()
|
105 |
+
# cur_feat /= cur_feat.max()
|
106 |
+
# cur_feat *= 255
|
107 |
+
# cur_feat = np.clip(cur_feat, 0, 255).astype('uint8')
|
108 |
+
# cv2.imwrite(cur_out_fname + f'_lev{level_i}_ch{channel_i}.png', cur_feat)
|
109 |
+
elif level_i >= max_level:
|
110 |
+
break
|
111 |
+
except KeyboardInterrupt:
|
112 |
+
LOGGER.warning('Interrupted by user')
|
113 |
+
except Exception as ex:
|
114 |
+
LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
|
115 |
+
sys.exit(1)
|
116 |
+
|
117 |
+
|
118 |
+
if __name__ == '__main__':
|
119 |
+
main()
|
bin/report_from_tb.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
|
7 |
+
import tensorflow as tf
|
8 |
+
from torch.utils.tensorboard import SummaryWriter
|
9 |
+
|
10 |
+
|
11 |
+
GROUPING_RULES = [
|
12 |
+
re.compile(r'^(?P<group>train|test|val|extra_val_.*?(256|512))_(?P<title>.*)', re.I)
|
13 |
+
]
|
14 |
+
|
15 |
+
|
16 |
+
DROP_RULES = [
|
17 |
+
re.compile(r'_std$', re.I)
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def need_drop(tag):
|
22 |
+
for rule in DROP_RULES:
|
23 |
+
if rule.search(tag):
|
24 |
+
return True
|
25 |
+
return False
|
26 |
+
|
27 |
+
|
28 |
+
def get_group_and_title(tag):
|
29 |
+
for rule in GROUPING_RULES:
|
30 |
+
match = rule.search(tag)
|
31 |
+
if match is None:
|
32 |
+
continue
|
33 |
+
return match.group('group'), match.group('title')
|
34 |
+
return None, None
|
35 |
+
|
36 |
+
|
37 |
+
def main(args):
|
38 |
+
os.makedirs(args.outdir, exist_ok=True)
|
39 |
+
|
40 |
+
ignored_events = set()
|
41 |
+
|
42 |
+
for orig_fname in glob.glob(args.inglob):
|
43 |
+
cur_dirpath = os.path.dirname(orig_fname) # remove filename, this should point to "version_0" directory
|
44 |
+
subdirname = os.path.basename(cur_dirpath) # == "version_0" most of time
|
45 |
+
exp_root_path = os.path.dirname(cur_dirpath) # remove "version_0"
|
46 |
+
exp_name = os.path.basename(exp_root_path)
|
47 |
+
|
48 |
+
writers_by_group = {}
|
49 |
+
|
50 |
+
for e in tf.compat.v1.train.summary_iterator(orig_fname):
|
51 |
+
for v in e.summary.value:
|
52 |
+
if need_drop(v.tag):
|
53 |
+
continue
|
54 |
+
|
55 |
+
cur_group, cur_title = get_group_and_title(v.tag)
|
56 |
+
if cur_group is None:
|
57 |
+
if v.tag not in ignored_events:
|
58 |
+
print(f'WARNING: Could not detect group for {v.tag}, ignoring it')
|
59 |
+
ignored_events.add(v.tag)
|
60 |
+
continue
|
61 |
+
|
62 |
+
cur_writer = writers_by_group.get(cur_group, None)
|
63 |
+
if cur_writer is None:
|
64 |
+
if args.include_version:
|
65 |
+
cur_outdir = os.path.join(args.outdir, exp_name, f'{subdirname}_{cur_group}')
|
66 |
+
else:
|
67 |
+
cur_outdir = os.path.join(args.outdir, exp_name, cur_group)
|
68 |
+
cur_writer = SummaryWriter(cur_outdir)
|
69 |
+
writers_by_group[cur_group] = cur_writer
|
70 |
+
|
71 |
+
cur_writer.add_scalar(cur_title, v.simple_value, global_step=e.step, walltime=e.wall_time)
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == '__main__':
|
75 |
+
import argparse
|
76 |
+
|
77 |
+
aparser = argparse.ArgumentParser()
|
78 |
+
aparser.add_argument('inglob', type=str)
|
79 |
+
aparser.add_argument('outdir', type=str)
|
80 |
+
aparser.add_argument('--include-version', action='store_true',
|
81 |
+
help='Include subdirectory name e.g. "version_0" into output path')
|
82 |
+
|
83 |
+
main(aparser.parse_args())
|
bin/sample_from_dataset.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import tqdm
|
7 |
+
from skimage import io
|
8 |
+
from skimage.segmentation import mark_boundaries
|
9 |
+
|
10 |
+
from saicinpainting.evaluation.data import InpaintingDataset
|
11 |
+
from saicinpainting.evaluation.vis import save_item_for_vis
|
12 |
+
|
13 |
+
def save_mask_for_sidebyside(item, out_file):
|
14 |
+
mask = item['mask']# > 0.5
|
15 |
+
if mask.ndim == 3:
|
16 |
+
mask = mask[0]
|
17 |
+
mask = np.clip(mask * 255, 0, 255).astype('uint8')
|
18 |
+
io.imsave(out_file, mask)
|
19 |
+
|
20 |
+
def save_img_for_sidebyside(item, out_file):
|
21 |
+
img = np.transpose(item['image'], (1, 2, 0))
|
22 |
+
img = np.clip(img * 255, 0, 255).astype('uint8')
|
23 |
+
io.imsave(out_file, img)
|
24 |
+
|
25 |
+
def save_masked_img_for_sidebyside(item, out_file):
|
26 |
+
mask = item['mask']
|
27 |
+
img = item['image']
|
28 |
+
|
29 |
+
img = (1-mask) * img + mask
|
30 |
+
img = np.transpose(img, (1, 2, 0))
|
31 |
+
|
32 |
+
img = np.clip(img * 255, 0, 255).astype('uint8')
|
33 |
+
io.imsave(out_file, img)
|
34 |
+
|
35 |
+
def main(args):
|
36 |
+
dataset = InpaintingDataset(args.datadir, img_suffix='.png')
|
37 |
+
|
38 |
+
area_bins = np.linspace(0, 1, args.area_bins + 1)
|
39 |
+
|
40 |
+
heights = []
|
41 |
+
widths = []
|
42 |
+
image_areas = []
|
43 |
+
hole_areas = []
|
44 |
+
hole_area_percents = []
|
45 |
+
area_bins_count = np.zeros(args.area_bins)
|
46 |
+
area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)]
|
47 |
+
|
48 |
+
bin2i = [[] for _ in range(args.area_bins)]
|
49 |
+
|
50 |
+
for i, item in enumerate(tqdm.tqdm(dataset)):
|
51 |
+
h, w = item['image'].shape[1:]
|
52 |
+
heights.append(h)
|
53 |
+
widths.append(w)
|
54 |
+
full_area = h * w
|
55 |
+
image_areas.append(full_area)
|
56 |
+
hole_area = (item['mask'] == 1).sum()
|
57 |
+
hole_areas.append(hole_area)
|
58 |
+
hole_percent = hole_area / full_area
|
59 |
+
hole_area_percents.append(hole_percent)
|
60 |
+
bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, len(area_bins_count) - 1)
|
61 |
+
area_bins_count[bin_i] += 1
|
62 |
+
bin2i[bin_i].append(i)
|
63 |
+
|
64 |
+
os.makedirs(args.outdir, exist_ok=True)
|
65 |
+
|
66 |
+
for bin_i in range(args.area_bins):
|
67 |
+
bindir = os.path.join(args.outdir, area_bin_titles[bin_i])
|
68 |
+
os.makedirs(bindir, exist_ok=True)
|
69 |
+
bin_idx = bin2i[bin_i]
|
70 |
+
for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False):
|
71 |
+
item = dataset[sample_i]
|
72 |
+
path = os.path.join(bindir, dataset.img_filenames[sample_i].split('/')[-1])
|
73 |
+
save_masked_img_for_sidebyside(item, path)
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__':
|
77 |
+
import argparse
|
78 |
+
|
79 |
+
aparser = argparse.ArgumentParser()
|
80 |
+
aparser.add_argument('--datadir', type=str,
|
81 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
82 |
+
aparser.add_argument('--outdir', type=str, help='Where to put results')
|
83 |
+
aparser.add_argument('--samples-n', type=int, default=10,
|
84 |
+
help='Number of sample images with masks to copy for visualization for each area bin')
|
85 |
+
aparser.add_argument('--area-bins', type=int, default=10, help='How many area bins to have')
|
86 |
+
|
87 |
+
main(aparser.parse_args())
|
bin/side_by_side.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
|
9 |
+
from saicinpainting.evaluation.utils import load_yaml
|
10 |
+
from saicinpainting.training.visualizers.base import visualize_mask_and_images
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
config = load_yaml(args.config)
|
15 |
+
|
16 |
+
datasets = [PrecomputedInpaintingResultsDataset(args.datadir, cur_predictdir, **config.dataset_kwargs)
|
17 |
+
for cur_predictdir in args.predictdirs]
|
18 |
+
assert len({len(ds) for ds in datasets}) == 1
|
19 |
+
len_first = len(datasets[0])
|
20 |
+
|
21 |
+
indices = list(range(len_first))
|
22 |
+
if len_first > args.max_n:
|
23 |
+
indices = sorted(random.sample(indices, args.max_n))
|
24 |
+
|
25 |
+
os.makedirs(args.outpath, exist_ok=True)
|
26 |
+
|
27 |
+
filename2i = {}
|
28 |
+
|
29 |
+
keys = ['image'] + [i for i in range(len(datasets))]
|
30 |
+
for img_i in indices:
|
31 |
+
try:
|
32 |
+
mask_fname = os.path.basename(datasets[0].mask_filenames[img_i])
|
33 |
+
if mask_fname in filename2i:
|
34 |
+
filename2i[mask_fname] += 1
|
35 |
+
idx = filename2i[mask_fname]
|
36 |
+
mask_fname_only, ext = os.path.split(mask_fname)
|
37 |
+
mask_fname = f'{mask_fname_only}_{idx}{ext}'
|
38 |
+
else:
|
39 |
+
filename2i[mask_fname] = 1
|
40 |
+
|
41 |
+
cur_vis_dict = datasets[0][img_i]
|
42 |
+
for ds_i, ds in enumerate(datasets):
|
43 |
+
cur_vis_dict[ds_i] = ds[img_i]['inpainted']
|
44 |
+
|
45 |
+
vis_img = visualize_mask_and_images(cur_vis_dict, keys,
|
46 |
+
last_without_mask=False,
|
47 |
+
mask_only_first=True,
|
48 |
+
black_mask=args.black)
|
49 |
+
vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')
|
50 |
+
|
51 |
+
out_fname = os.path.join(args.outpath, mask_fname)
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
|
56 |
+
cv2.imwrite(out_fname, vis_img)
|
57 |
+
except Exception as ex:
|
58 |
+
print(f'Could not process {img_i} due to {ex}')
|
59 |
+
|
60 |
+
|
61 |
+
if __name__ == '__main__':
|
62 |
+
import argparse
|
63 |
+
|
64 |
+
aparser = argparse.ArgumentParser()
|
65 |
+
aparser.add_argument('--max-n', type=int, default=100, help='Maximum number of images to print')
|
66 |
+
aparser.add_argument('--black', action='store_true', help='Whether to fill mask on GT with black')
|
67 |
+
aparser.add_argument('config', type=str, help='Path to evaluation config (e.g. configs/eval1.yaml)')
|
68 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
69 |
+
aparser.add_argument('datadir', type=str,
|
70 |
+
help='Path to folder with images and masks')
|
71 |
+
aparser.add_argument('predictdirs', type=str,
|
72 |
+
nargs='+',
|
73 |
+
help='Path to folders with predicts')
|
74 |
+
|
75 |
+
|
76 |
+
main(aparser.parse_args())
|
bin/split_tar.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
|
4 |
+
import tqdm
|
5 |
+
import webdataset as wds
|
6 |
+
|
7 |
+
|
8 |
+
def main(args):
|
9 |
+
input_dataset = wds.Dataset(args.infile)
|
10 |
+
output_dataset = wds.ShardWriter(args.outpattern)
|
11 |
+
for rec in tqdm.tqdm(input_dataset):
|
12 |
+
output_dataset.write(rec)
|
13 |
+
|
14 |
+
|
15 |
+
if __name__ == '__main__':
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
aparser = argparse.ArgumentParser()
|
19 |
+
aparser.add_argument('infile', type=str)
|
20 |
+
aparser.add_argument('outpattern', type=str)
|
21 |
+
|
22 |
+
main(aparser.parse_args())
|
bin/train.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import traceback
|
7 |
+
|
8 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
9 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
10 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
11 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
12 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
13 |
+
|
14 |
+
import hydra
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
from pytorch_lightning import Trainer
|
17 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
18 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
19 |
+
from pytorch_lightning.plugins import DDPPlugin
|
20 |
+
|
21 |
+
from saicinpainting.training.trainers import make_training_model
|
22 |
+
from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \
|
23 |
+
handle_deterministic_config
|
24 |
+
|
25 |
+
LOGGER = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
@handle_ddp_subprocess()
|
29 |
+
@hydra.main(config_path='../configs/training', config_name='tiny_test.yaml')
|
30 |
+
def main(config: OmegaConf):
|
31 |
+
try:
|
32 |
+
need_set_deterministic = handle_deterministic_config(config)
|
33 |
+
|
34 |
+
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
|
35 |
+
|
36 |
+
is_in_ddp_subprocess = handle_ddp_parent_process()
|
37 |
+
|
38 |
+
config.visualizer.outdir = os.path.join(os.getcwd(), config.visualizer.outdir)
|
39 |
+
if not is_in_ddp_subprocess:
|
40 |
+
LOGGER.info(OmegaConf.to_yaml(config))
|
41 |
+
OmegaConf.save(config, os.path.join(os.getcwd(), 'config.yaml'))
|
42 |
+
|
43 |
+
checkpoints_dir = os.path.join(os.getcwd(), 'models')
|
44 |
+
os.makedirs(checkpoints_dir, exist_ok=True)
|
45 |
+
|
46 |
+
# there is no need to suppress this logger in ddp, because it handles rank on its own
|
47 |
+
metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd()))
|
48 |
+
metrics_logger.log_hyperparams(config)
|
49 |
+
|
50 |
+
training_model = make_training_model(config)
|
51 |
+
|
52 |
+
trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True)
|
53 |
+
if need_set_deterministic:
|
54 |
+
trainer_kwargs['deterministic'] = True
|
55 |
+
|
56 |
+
trainer = Trainer(
|
57 |
+
# there is no need to suppress checkpointing in ddp, because it handles rank on its own
|
58 |
+
callbacks=ModelCheckpoint(dirpath=checkpoints_dir, **config.trainer.checkpoint_kwargs),
|
59 |
+
logger=metrics_logger,
|
60 |
+
default_root_dir=os.getcwd(),
|
61 |
+
**trainer_kwargs
|
62 |
+
)
|
63 |
+
trainer.fit(training_model)
|
64 |
+
except KeyboardInterrupt:
|
65 |
+
LOGGER.warning('Interrupted by user')
|
66 |
+
except Exception as ex:
|
67 |
+
LOGGER.critical(f'Training failed due to {ex}:\n{traceback.format_exc()}')
|
68 |
+
sys.exit(1)
|
69 |
+
|
70 |
+
|
71 |
+
if __name__ == '__main__':
|
72 |
+
main()
|
canvas.png
ADDED
![]() |
conda_env.yml
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: lama
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
- conda-forge
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1=main
|
7 |
+
- _openmp_mutex=4.5=1_gnu
|
8 |
+
- absl-py=0.13.0=py36h06a4308_0
|
9 |
+
- aiohttp=3.7.4.post0=py36h7f8727e_2
|
10 |
+
- antlr-python-runtime=4.8=py36h9f0ad1d_2
|
11 |
+
- async-timeout=3.0.1=py36h06a4308_0
|
12 |
+
- attrs=21.2.0=pyhd3eb1b0_0
|
13 |
+
- blas=1.0=mkl
|
14 |
+
- blinker=1.4=py36h06a4308_0
|
15 |
+
- brotlipy=0.7.0=py36h27cfd23_1003
|
16 |
+
- bzip2=1.0.8=h7b6447c_0
|
17 |
+
- c-ares=1.17.1=h27cfd23_0
|
18 |
+
- ca-certificates=2021.7.5=h06a4308_1
|
19 |
+
- cachetools=4.2.2=pyhd3eb1b0_0
|
20 |
+
- certifi=2021.5.30=py36h06a4308_0
|
21 |
+
- cffi=1.14.6=py36h400218f_0
|
22 |
+
- chardet=4.0.0=py36h06a4308_1003
|
23 |
+
- charset-normalizer=2.0.4=pyhd3eb1b0_0
|
24 |
+
- click=8.0.1=pyhd3eb1b0_0
|
25 |
+
- cloudpickle=2.0.0=pyhd3eb1b0_0
|
26 |
+
- coverage=5.5=py36h27cfd23_2
|
27 |
+
- cryptography=3.4.7=py36hd23ed53_0
|
28 |
+
- cudatoolkit=10.2.89=hfd86e86_1
|
29 |
+
- cycler=0.10.0=py36_0
|
30 |
+
- cython=0.29.24=py36h295c915_0
|
31 |
+
- cytoolz=0.11.0=py36h7b6447c_0
|
32 |
+
- dask-core=1.1.4=py36_1
|
33 |
+
- dataclasses=0.8=pyh4f3eec9_6
|
34 |
+
- dbus=1.13.18=hb2f20db_0
|
35 |
+
- decorator=5.0.9=pyhd3eb1b0_0
|
36 |
+
- easydict=1.9=py_0
|
37 |
+
- expat=2.4.1=h2531618_2
|
38 |
+
- ffmpeg=4.2.2=h20bf706_0
|
39 |
+
- fontconfig=2.13.1=h6c09931_0
|
40 |
+
- freetype=2.10.4=h5ab3b9f_0
|
41 |
+
- fsspec=2021.8.1=pyhd3eb1b0_0
|
42 |
+
- future=0.18.2=py36_1
|
43 |
+
- glib=2.69.1=h5202010_0
|
44 |
+
- gmp=6.2.1=h2531618_2
|
45 |
+
- gnutls=3.6.15=he1e5248_0
|
46 |
+
- google-auth=1.33.0=pyhd3eb1b0_0
|
47 |
+
- google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
|
48 |
+
- grpcio=1.36.1=py36h2157cd5_1
|
49 |
+
- gst-plugins-base=1.14.0=h8213a91_2
|
50 |
+
- gstreamer=1.14.0=h28cd5cc_2
|
51 |
+
- hydra-core=1.1.0=pyhd8ed1ab_0
|
52 |
+
- icu=58.2=he6710b0_3
|
53 |
+
- idna=3.2=pyhd3eb1b0_0
|
54 |
+
- idna_ssl=1.1.0=py36h06a4308_0
|
55 |
+
- imageio=2.9.0=pyhd3eb1b0_0
|
56 |
+
- importlib-metadata=4.8.1=py36h06a4308_0
|
57 |
+
- importlib_resources=5.2.0=pyhd3eb1b0_1
|
58 |
+
- intel-openmp=2021.3.0=h06a4308_3350
|
59 |
+
- joblib=1.0.1=pyhd3eb1b0_0
|
60 |
+
- jpeg=9b=h024ee3a_2
|
61 |
+
- kiwisolver=1.3.1=py36h2531618_0
|
62 |
+
- lame=3.100=h7b6447c_0
|
63 |
+
- lcms2=2.12=h3be6417_0
|
64 |
+
- ld_impl_linux-64=2.35.1=h7274673_9
|
65 |
+
- libblas=3.9.0=11_linux64_mkl
|
66 |
+
- libcblas=3.9.0=11_linux64_mkl
|
67 |
+
- libffi=3.3=he6710b0_2
|
68 |
+
- libgcc-ng=9.3.0=h5101ec6_17
|
69 |
+
- libgfortran-ng=9.3.0=ha5ec8a7_17
|
70 |
+
- libgfortran5=9.3.0=ha5ec8a7_17
|
71 |
+
- libgomp=9.3.0=h5101ec6_17
|
72 |
+
- libidn2=2.3.2=h7f8727e_0
|
73 |
+
- liblapack=3.9.0=11_linux64_mkl
|
74 |
+
- libopus=1.3.1=h7b6447c_0
|
75 |
+
- libpng=1.6.37=hbc83047_0
|
76 |
+
- libprotobuf=3.17.2=h4ff587b_1
|
77 |
+
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
78 |
+
- libtasn1=4.16.0=h27cfd23_0
|
79 |
+
- libtiff=4.2.0=h85742a9_0
|
80 |
+
- libunistring=0.9.10=h27cfd23_0
|
81 |
+
- libuuid=1.0.3=h1bed415_2
|
82 |
+
- libuv=1.40.0=h7b6447c_0
|
83 |
+
- libvpx=1.7.0=h439df22_0
|
84 |
+
- libwebp-base=1.2.0=h27cfd23_0
|
85 |
+
- libxcb=1.14=h7b6447c_0
|
86 |
+
- libxml2=2.9.12=h03d6c58_0
|
87 |
+
- lz4-c=1.9.3=h295c915_1
|
88 |
+
- markdown=3.3.4=py36h06a4308_0
|
89 |
+
- matplotlib=3.3.4=py36h06a4308_0
|
90 |
+
- matplotlib-base=3.3.4=py36h62a2d02_0
|
91 |
+
- mkl=2021.3.0=h06a4308_520
|
92 |
+
- multidict=5.1.0=py36h27cfd23_2
|
93 |
+
- ncurses=6.2=he6710b0_1
|
94 |
+
- nettle=3.7.3=hbbd107a_1
|
95 |
+
- networkx=2.2=py36_1
|
96 |
+
- ninja=1.10.2=hff7bd54_1
|
97 |
+
- numpy=1.19.5=py36hfc0c790_2
|
98 |
+
- oauthlib=3.1.1=pyhd3eb1b0_0
|
99 |
+
- olefile=0.46=py36_0
|
100 |
+
- omegaconf=2.1.1=py36h5fab9bb_0
|
101 |
+
- openh264=2.1.0=hd408876_0
|
102 |
+
- openjpeg=2.4.0=h3ad879b_0
|
103 |
+
- openssl=1.1.1l=h7f8727e_0
|
104 |
+
- packaging=21.0=pyhd3eb1b0_0
|
105 |
+
- pandas=1.1.5=py36h284efc9_0
|
106 |
+
- pcre=8.45=h295c915_0
|
107 |
+
- pillow=8.3.1=py36h2c7a002_0
|
108 |
+
- pip=21.0.1=py36h06a4308_0
|
109 |
+
- protobuf=3.17.2=py36h295c915_0
|
110 |
+
- pyasn1=0.4.8=pyhd3eb1b0_0
|
111 |
+
- pyasn1-modules=0.2.8=py_0
|
112 |
+
- pycparser=2.20=py_2
|
113 |
+
- pyjwt=2.1.0=py36h06a4308_0
|
114 |
+
- pyopenssl=20.0.1=pyhd3eb1b0_1
|
115 |
+
- pyparsing=2.4.7=pyhd3eb1b0_0
|
116 |
+
- pyqt=5.9.2=py36h05f1152_2
|
117 |
+
- pysocks=1.7.1=py36h06a4308_0
|
118 |
+
- python=3.6.13=h12debd9_1
|
119 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
120 |
+
- python_abi=3.6=2_cp36m
|
121 |
+
- pytz=2021.1=pyhd3eb1b0_0
|
122 |
+
- pywavelets=1.1.1=py36h7b6447c_2
|
123 |
+
- pyyaml=5.4.1=py36h27cfd23_1
|
124 |
+
- qt=5.9.7=h5867ecd_1
|
125 |
+
- readline=8.1=h27cfd23_0
|
126 |
+
- requests=2.26.0=pyhd3eb1b0_0
|
127 |
+
- requests-oauthlib=1.3.0=py_0
|
128 |
+
- rsa=4.7.2=pyhd3eb1b0_1
|
129 |
+
- scikit-image=0.17.2=py36h284efc9_4
|
130 |
+
- scikit-learn=0.24.2=py36ha9443f7_0
|
131 |
+
- scipy=1.5.3=py36h9e8f40b_0
|
132 |
+
- setuptools=58.0.4=py36h06a4308_0
|
133 |
+
- sip=4.19.8=py36hf484d3e_0
|
134 |
+
- six=1.16.0=pyhd3eb1b0_0
|
135 |
+
- sqlite=3.36.0=hc218d9a_0
|
136 |
+
- tabulate=0.8.9=py36h06a4308_0
|
137 |
+
- tensorboard=2.4.0=pyhc547734_0
|
138 |
+
- tensorboard-plugin-wit=1.6.0=py_0
|
139 |
+
- threadpoolctl=2.2.0=pyh0d69192_0
|
140 |
+
- tifffile=2020.10.1=py36hdd07704_2
|
141 |
+
- tk=8.6.11=h1ccaba5_0
|
142 |
+
- toolz=0.11.1=pyhd3eb1b0_0
|
143 |
+
- tqdm=4.62.2=pyhd3eb1b0_1
|
144 |
+
- typing-extensions=3.10.0.2=hd3eb1b0_0
|
145 |
+
- typing_extensions=3.10.0.2=pyh06a4308_0
|
146 |
+
- urllib3=1.26.6=pyhd3eb1b0_1
|
147 |
+
- werkzeug=2.0.1=pyhd3eb1b0_0
|
148 |
+
- wheel=0.37.0=pyhd3eb1b0_1
|
149 |
+
- x264=1!157.20191217=h7b6447c_0
|
150 |
+
- xz=5.2.5=h7b6447c_0
|
151 |
+
- yaml=0.2.5=h7b6447c_0
|
152 |
+
- yarl=1.6.3=py36h27cfd23_0
|
153 |
+
- zipp=3.5.0=pyhd3eb1b0_0
|
154 |
+
- zlib=1.2.11=h7b6447c_3
|
155 |
+
- zstd=1.4.9=haebb681_0
|
156 |
+
- pip:
|
157 |
+
- albumentations==0.5.2
|
158 |
+
- braceexpand==0.1.7
|
159 |
+
- imgaug==0.4.0
|
160 |
+
- kornia==0.5.0
|
161 |
+
- opencv-python==4.5.3.56
|
162 |
+
- opencv-python-headless==4.5.3.56
|
163 |
+
- shapely==1.7.1
|
164 |
+
- webdataset==0.1.76
|
165 |
+
- wldhx-yadisk-direct==0.0.6
|
configs/analyze_mask_errors.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_kwargs:
|
2 |
+
img_suffix: .jpg
|
3 |
+
inpainted_suffix: .jpg
|
4 |
+
|
5 |
+
take_global_top: 30
|
6 |
+
take_worst_best_top: 30
|
7 |
+
take_overlapping_top: 30
|
configs/data_gen/gen_segm_dataset1.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: segmentation
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
confidence_threshold: 0.5
|
5 |
+
max_object_area: 0.5
|
6 |
+
min_mask_area: 0.02
|
7 |
+
downsample_levels: 6
|
8 |
+
num_variants_per_mask: 5
|
9 |
+
rigidness_mode: 1
|
10 |
+
max_foreground_coverage: 0.3
|
11 |
+
max_foreground_intersection: 0.7
|
12 |
+
max_mask_intersection: 0.1
|
13 |
+
max_hidden_area: 0.1
|
14 |
+
max_scale_change: 0.25
|
15 |
+
horizontal_flip: True
|
16 |
+
max_vertical_shift: 0.2
|
17 |
+
position_shuffle: True
|
18 |
+
|
19 |
+
max_masks_per_image: 5
|
20 |
+
|
21 |
+
cropping:
|
22 |
+
out_min_size: 512
|
23 |
+
handle_small_mode: drop
|
24 |
+
out_square_crop: True
|
25 |
+
crop_min_overlap: 0.5
|
configs/data_gen/gen_segm_dataset3.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: segmentation
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
confidence_threshold: 0.5
|
5 |
+
max_object_area: 0.5
|
6 |
+
min_mask_area: 0.07
|
7 |
+
downsample_levels: 6
|
8 |
+
num_variants_per_mask: 3
|
9 |
+
rigidness_mode: 1
|
10 |
+
max_foreground_coverage: 0.4
|
11 |
+
max_foreground_intersection: 0.8
|
12 |
+
max_mask_intersection: 0.2
|
13 |
+
max_hidden_area: 0.1
|
14 |
+
max_scale_change: 0.25
|
15 |
+
horizontal_flip: True
|
16 |
+
max_vertical_shift: 0.3
|
17 |
+
position_shuffle: True
|
18 |
+
|
19 |
+
max_masks_per_image: 3
|
20 |
+
|
21 |
+
cropping:
|
22 |
+
out_min_size: 512
|
23 |
+
handle_small_mode: drop
|
24 |
+
out_square_crop: True
|
25 |
+
crop_min_overlap: 0.5
|
configs/data_gen/random_medium_256.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 5
|
8 |
+
max_width: 50
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 100
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 0
|
15 |
+
bbox_min_size: 10
|
16 |
+
bbox_max_size: 50
|
17 |
+
max_times: 5
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 256
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
configs/data_gen/random_medium_512.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 10
|
8 |
+
max_width: 100
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 200
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 0
|
15 |
+
bbox_min_size: 30
|
16 |
+
bbox_max_size: 150
|
17 |
+
max_times: 5
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 512
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
configs/data_gen/random_thick_256.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 1
|
7 |
+
max_times: 5
|
8 |
+
max_width: 100
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 200
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 10
|
15 |
+
bbox_min_size: 30
|
16 |
+
bbox_max_size: 150
|
17 |
+
max_times: 3
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 256
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
configs/data_gen/random_thick_512.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 1
|
7 |
+
max_times: 5
|
8 |
+
max_width: 250
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 450
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 10
|
15 |
+
bbox_min_size: 30
|
16 |
+
bbox_max_size: 300
|
17 |
+
max_times: 4
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 512
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
configs/data_gen/random_thin_256.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 50
|
8 |
+
max_width: 10
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 40
|
11 |
+
box_proba: 0
|
12 |
+
segm_proba: 0
|
13 |
+
squares_proba: 0
|
14 |
+
|
15 |
+
variants_n: 5
|
16 |
+
|
17 |
+
max_masks_per_image: 1
|
18 |
+
|
19 |
+
cropping:
|
20 |
+
out_min_size: 256
|
21 |
+
handle_small_mode: upscale
|
22 |
+
out_square_crop: True
|
23 |
+
crop_min_overlap: 1
|
24 |
+
|
25 |
+
max_tamper_area: 0.5
|
configs/data_gen/random_thin_512.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 70
|
8 |
+
max_width: 20
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 100
|
11 |
+
box_proba: 0
|
12 |
+
segm_proba: 0
|
13 |
+
squares_proba: 0
|
14 |
+
|
15 |
+
variants_n: 5
|
16 |
+
|
17 |
+
max_masks_per_image: 1
|
18 |
+
|
19 |
+
cropping:
|
20 |
+
out_min_size: 512
|
21 |
+
handle_small_mode: upscale
|
22 |
+
out_square_crop: True
|
23 |
+
crop_min_overlap: 1
|
24 |
+
|
25 |
+
max_tamper_area: 0.5
|
configs/data_gen/segm_256.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: segmentation
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
confidence_threshold: 0.5
|
5 |
+
max_object_area: 0.5
|
6 |
+
min_mask_area: 0.05
|
7 |
+
downsample_levels: 6
|
8 |
+
num_variants_per_mask: 3
|
9 |
+
rigidness_mode: 1
|
10 |
+
max_foreground_coverage: 1 # turn off filtering by overlap
|
11 |
+
max_foreground_intersection: 1 # turn off filtering by overlap
|
12 |
+
max_mask_intersection: 0.2 # the lower this value the higher diversity
|
13 |
+
max_hidden_area: 0.5
|
14 |
+
max_scale_change: 0.25
|
15 |
+
horizontal_flip: True
|
16 |
+
max_vertical_shift: 0.3
|
17 |
+
position_shuffle: True
|
18 |
+
|
19 |
+
max_masks_per_image: 1
|
20 |
+
|
21 |
+
cropping:
|
22 |
+
out_min_size: 256
|
23 |
+
handle_small_mode: upscale
|
24 |
+
out_square_crop: True
|
25 |
+
crop_min_overlap: 1
|
26 |
+
|
27 |
+
max_tamper_area: 0.5
|
configs/data_gen/segm_512.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: segmentation
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
confidence_threshold: 0.5
|
5 |
+
max_object_area: 0.5
|
6 |
+
min_mask_area: 0.05
|
7 |
+
downsample_levels: 6
|
8 |
+
num_variants_per_mask: 3
|
9 |
+
rigidness_mode: 1
|
10 |
+
max_foreground_coverage: 1 # turn off filtering by overlap
|
11 |
+
max_foreground_intersection: 1 # turn off filtering by overlap
|
12 |
+
max_mask_intersection: 0.2 # the lower this value the higher diversity
|
13 |
+
max_hidden_area: 0.5
|
14 |
+
max_scale_change: 0.25
|
15 |
+
horizontal_flip: True
|
16 |
+
max_vertical_shift: 0.3
|
17 |
+
position_shuffle: True
|
18 |
+
|
19 |
+
max_masks_per_image: 1
|
20 |
+
|
21 |
+
cropping:
|
22 |
+
out_min_size: 512
|
23 |
+
handle_small_mode: upscale
|
24 |
+
out_square_crop: True
|
25 |
+
crop_min_overlap: 1
|
26 |
+
|
27 |
+
max_tamper_area: 0.5
|
configs/data_gen/sr_256.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 0
|
5 |
+
box_proba: 0
|
6 |
+
segm_proba: 0
|
7 |
+
squares_proba: 0
|
8 |
+
superres_proba: 1
|
9 |
+
superres_kwargs:
|
10 |
+
min_step: 2
|
11 |
+
max_step: 4
|
12 |
+
min_width: 1
|
13 |
+
max_width: 3
|
14 |
+
|
15 |
+
variants_n: 5
|
16 |
+
|
17 |
+
max_masks_per_image: 1
|
18 |
+
|
19 |
+
cropping:
|
20 |
+
out_min_size: 256
|
21 |
+
handle_small_mode: upscale
|
22 |
+
out_square_crop: True
|
23 |
+
crop_min_overlap: 1
|
24 |
+
|
25 |
+
max_tamper_area: 1
|
configs/data_gen/whydra/location/mml-ws01-celeba-hq.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
root_dir: /media/inpainting/CelebA-HQ
|
4 |
+
out_dir: /media/inpainting/paper_data/CelebA-HQ_val_test
|
5 |
+
extension: jpg
|
configs/data_gen/whydra/location/mml-ws01-ffhq.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
root_dir: /media/inpainting/FFHQ/
|
4 |
+
out_dir: /media/inpainting/paper_data/FFHQ_val
|
5 |
+
extension: png
|
configs/data_gen/whydra/location/mml-ws01-paris.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
root_dir: /media/inpainting/Paris_StreetView_Dataset
|
4 |
+
out_dir: /media/inpainting/paper_data/Paris_StreetView_Dataset_val
|
5 |
+
extension: png
|
configs/data_gen/whydra/location/mml7-places.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
root_dir: /data/inpainting/Places365
|
4 |
+
out_dir: /data/inpainting/paper_data/Places365_val_test
|
5 |
+
extension: jpg
|