Update saicinpainting/training/trainers/default.py
saicinpainting/training/trainers/default.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-
+
 import torch
 import torch.nn.functional as F
 from omegaconf import OmegaConf
@@ -13,30 +13,6 @@ from saicinpainting.utils import add_prefix_to_keys, get_ramp
 
 LOGGER = logging.getLogger(__name__)
 
-def resize_to_square(image, target_size):
-    h, w = image.shape[:2]
-    if h == w:
-        return cv2.resize(image, (target_size, target_size))
-
-    dif = h if h > w else w
-    interpolation = cv2.INTER_AREA if dif > target_size else cv2.INTER_CUBIC
-
-    x_pos = (dif - w) // 2
-    y_pos = (dif - h) // 2
-
-    if len(image.shape) == 2:
-        mask = np.zeros((dif, dif), dtype=image.dtype)
-        mask[y_pos:y_pos+h, x_pos:x_pos+w] = image
-    else:
-        mask = np.zeros((dif, dif, image.shape[2]), dtype=image.dtype)
-        mask[y_pos:y_pos+h, x_pos:x_pos+w, :] = image
-
-    return cv2.resize(mask, (target_size, target_size), interpolation=interpolation)
-
-# Usage
-target_size = 256
-resized_frame = resize_to_square(frame, target_size)
-
 
 def make_constant_area_crop_batch(batch, **kwargs):
     crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
@@ -48,9 +24,25 @@ def make_constant_area_crop_batch(batch, **kwargs):
 
 
 class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
-    def __init__(self, *args,
+    def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
+                 add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
+                 distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
+                 fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
+                 **kwargs):
         super().__init__(*args, **kwargs)
-        self.
+        self.concat_mask = concat_mask
+        self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None
+        self.image_to_discriminator = image_to_discriminator
+        self.add_noise_kwargs = add_noise_kwargs
+        self.noise_fill_hole = noise_fill_hole
+        self.const_area_crop_kwargs = const_area_crop_kwargs
+        self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \
+            if distance_weighter_kwargs is not None else None
+        self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
+
+        self.fake_fakes_proba = fake_fakes_proba
+        if self.fake_fakes_proba > 1e-3:
+            self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {}))
 
     def forward(self, batch):
         if self.training and self.rescale_size_getter is not None:
@@ -58,29 +50,6 @@ class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
             batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
             batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
 
-        # Add the resize step here
-        resized_images = []
-        resized_masks = []
-        for img, mask in zip(batch['image'], batch['mask']):
-            # Convert the tensors to numpy arrays
-            img_np = img.permute(1, 2, 0).cpu().numpy()
-            mask_np = mask.squeeze().cpu().numpy()
-
-            # Resize
-            img_resized = resize_to_square(img_np, self.target_size)
-            mask_resized = resize_to_square(mask_np, self.target_size)
-
-            # Convert back to tensors
-            img_resized = torch.from_numpy(img_resized).permute(2, 0, 1).float().to(img.device)
-            mask_resized = torch.from_numpy(mask_resized).unsqueeze(0).float().to(mask.device)
-
-            resized_images.append(img_resized)
-            resized_masks.append(mask_resized)
-
-        batch['image'] = torch.stack(resized_images)
-        batch['mask'] = torch.stack(resized_masks)
-
-        # Continue with the rest of the forward method
         if self.training and self.const_area_crop_kwargs is not None:
             batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
 
@@ -203,4 +172,4 @@ class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
            metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss
            metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_'))
 
-        return total_loss, metrics
+        return total_loss, metrics
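The removed block moved every sample to the CPU, letterboxed it with cv2/numpy, and copied it back inside forward. For reference only, a minimal sketch of the same pad-to-square-then-resize idea using batched torch ops; the name pad_to_square_and_resize and the 256 target are illustrative and not part of this repository:

import torch
import torch.nn.functional as F

def pad_to_square_and_resize(images, masks, target_size=256):
    # images: (N, C, H, W) float tensor; masks: (N, 1, H, W) float tensor.
    # Zero-pad the shorter side to a square canvas, then resize, staying on
    # whatever device the tensors already live on (no numpy/cv2 round-trip).
    _, _, h, w = images.shape
    side = max(h, w)
    pad_left = (side - w) // 2
    pad_top = (side - h) // 2
    # F.pad padding order for a 4D input: (left, right, top, bottom).
    padding = (pad_left, side - w - pad_left, pad_top, side - h - pad_top)
    images = F.pad(images, padding)
    masks = F.pad(masks, padding)
    images = F.interpolate(images, size=(target_size, target_size),
                           mode='bilinear', align_corners=False)
    masks = F.interpolate(masks, size=(target_size, target_size), mode='nearest')
    return images, masks

# Example: a 2x3x200x300 batch is padded to 300x300, then resized to 256x256.
imgs = torch.rand(2, 3, 200, 300)
msks = (torch.rand(2, 1, 200, 300) > 0.5).float()
imgs_sq, msks_sq = pad_to_square_and_resize(imgs, msks)  # (2, 3, 256, 256), (2, 1, 256, 256)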
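The rescale branch kept in forward already resizes the batch with F.interpolate. As a toy check, unrelated to the repository's data, of why the mask uses mode='nearest' while the image uses 'bilinear':

import torch
import torch.nn.functional as F

# Upsample a tiny binary mask both ways: 'nearest' keeps values in {0, 1},
# while 'bilinear' also produces fractional values that code treating the
# mask as binary would not expect.
mask = torch.tensor([[[[0., 1.],
                       [1., 0.]]]])
up_nearest = F.interpolate(mask, size=(4, 4), mode='nearest')
up_bilinear = F.interpolate(mask, size=(4, 4), mode='bilinear', align_corners=False)
print(up_nearest.unique())   # tensor([0., 1.])
print(up_bilinear.unique())  # includes values strictly between 0 and 1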