ozgurkara commited on
Commit
eb9a9b4
1 Parent(s): 9805cc2

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. annotator/__pycache__/annotator_path.cpython-38.pyc +0 -0
  2. annotator/__pycache__/util.cpython-38.pyc +0 -0
  3. annotator/annotator_path.py +17 -0
  4. annotator/binary/__init__.py +14 -0
  5. annotator/canny/__init__.py +5 -0
  6. annotator/canny/__pycache__/__init__.cpython-311.pyc +0 -0
  7. annotator/canny/__pycache__/__init__.cpython-38.pyc +0 -0
  8. annotator/canny/__pycache__/__init__.cpython-39.pyc +0 -0
  9. annotator/clipvision/__init__.py +127 -0
  10. annotator/clipvision/clip_vision_h_uc.data +0 -0
  11. annotator/color/__init__.py +20 -0
  12. annotator/hed/__init__.py +97 -0
  13. annotator/hed/__pycache__/__init__.cpython-311.pyc +0 -0
  14. annotator/hed/__pycache__/__init__.cpython-38.pyc +0 -0
  15. annotator/hed/__pycache__/__init__.cpython-39.pyc +0 -0
  16. annotator/keypose/__init__.py +212 -0
  17. annotator/keypose/faster_rcnn_r50_fpn_coco.py +182 -0
  18. annotator/keypose/hrnet_w48_coco_256x192.py +169 -0
  19. annotator/lama/__init__.py +58 -0
  20. annotator/lama/config.yaml +157 -0
  21. annotator/lama/saicinpainting/__init__.py +0 -0
  22. annotator/lama/saicinpainting/training/__init__.py +0 -0
  23. annotator/lama/saicinpainting/training/data/__init__.py +0 -0
  24. annotator/lama/saicinpainting/training/data/masks.py +332 -0
  25. annotator/lama/saicinpainting/training/losses/__init__.py +0 -0
  26. annotator/lama/saicinpainting/training/losses/adversarial.py +177 -0
  27. annotator/lama/saicinpainting/training/losses/constants.py +152 -0
  28. annotator/lama/saicinpainting/training/losses/distance_weighting.py +126 -0
  29. annotator/lama/saicinpainting/training/losses/feature_matching.py +33 -0
  30. annotator/lama/saicinpainting/training/losses/perceptual.py +113 -0
  31. annotator/lama/saicinpainting/training/losses/segmentation.py +43 -0
  32. annotator/lama/saicinpainting/training/losses/style_loss.py +155 -0
  33. annotator/lama/saicinpainting/training/modules/__init__.py +31 -0
  34. annotator/lama/saicinpainting/training/modules/base.py +80 -0
  35. annotator/lama/saicinpainting/training/modules/depthwise_sep_conv.py +17 -0
  36. annotator/lama/saicinpainting/training/modules/fake_fakes.py +47 -0
  37. annotator/lama/saicinpainting/training/modules/ffc.py +485 -0
  38. annotator/lama/saicinpainting/training/modules/multidilated_conv.py +98 -0
  39. annotator/lama/saicinpainting/training/modules/multiscale.py +244 -0
  40. annotator/lama/saicinpainting/training/modules/pix2pixhd.py +669 -0
  41. annotator/lama/saicinpainting/training/modules/spatial_transform.py +49 -0
  42. annotator/lama/saicinpainting/training/modules/squeeze_excitation.py +20 -0
  43. annotator/lama/saicinpainting/training/trainers/__init__.py +29 -0
  44. annotator/lama/saicinpainting/training/trainers/base.py +293 -0
  45. annotator/lama/saicinpainting/training/trainers/default.py +175 -0
  46. annotator/lama/saicinpainting/training/visualizers/__init__.py +15 -0
  47. annotator/lama/saicinpainting/training/visualizers/base.py +73 -0
  48. annotator/lama/saicinpainting/training/visualizers/colors.py +76 -0
  49. annotator/lama/saicinpainting/training/visualizers/directory.py +36 -0
  50. annotator/lama/saicinpainting/training/visualizers/noop.py +9 -0
annotator/__pycache__/annotator_path.cpython-38.pyc ADDED
Binary file (583 Bytes). View file
 
annotator/__pycache__/util.cpython-38.pyc ADDED
Binary file (2.09 kB). View file
 
annotator/annotator_path.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import utils.constants as const
4
+
5
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6
+ models_path = f'{const.CWD}/pretrained_models'
7
+
8
+ clip_vision_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision')
9
+ # clip vision is always inside controlnet "extensions\sd-webui-controlnet"
10
+ # and any problem can be solved by removing controlnet and reinstall
11
+
12
+ models_path = os.path.realpath(models_path)
13
+ os.makedirs(models_path, exist_ok=True)
14
+ print(f'ControlNet preprocessor location: {models_path}')
15
+ # Make sure that the default location is inside controlnet "extensions\sd-webui-controlnet"
16
+ # so that any problem can be solved by removing controlnet and reinstall
17
+ # if users do not change configs on their own (otherwise users will know what is wrong)
annotator/binary/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+
4
+ def apply_binary(img, bin_threshold):
5
+ img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
6
+
7
+ if bin_threshold == 0 or bin_threshold == 255:
8
+ # Otsu's threshold
9
+ otsu_threshold, img_bin = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
10
+ print("Otsu threshold:", otsu_threshold)
11
+ else:
12
+ _, img_bin = cv2.threshold(img_gray, bin_threshold, 255, cv2.THRESH_BINARY_INV)
13
+
14
+ return cv2.cvtColor(img_bin, cv2.COLOR_GRAY2RGB)
annotator/canny/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import cv2
2
+
3
+
4
+ def apply_canny(img, low_threshold, high_threshold):
5
+ return cv2.Canny(img, low_threshold, high_threshold)
annotator/canny/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (433 Bytes). View file
 
annotator/canny/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (327 Bytes). View file
 
annotator/canny/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (344 Bytes). View file
 
annotator/clipvision/__init__.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ from modules import devices
5
+ from modules.modelloader import load_file_from_url
6
+ from annotator.annotator_path import models_path
7
+ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
8
+
9
+
10
+ config_clip_g = {
11
+ "attention_dropout": 0.0,
12
+ "dropout": 0.0,
13
+ "hidden_act": "gelu",
14
+ "hidden_size": 1664,
15
+ "image_size": 224,
16
+ "initializer_factor": 1.0,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 8192,
19
+ "layer_norm_eps": 1e-05,
20
+ "model_type": "clip_vision_model",
21
+ "num_attention_heads": 16,
22
+ "num_channels": 3,
23
+ "num_hidden_layers": 48,
24
+ "patch_size": 14,
25
+ "projection_dim": 1280,
26
+ "torch_dtype": "float32"
27
+ }
28
+
29
+ config_clip_h = {
30
+ "attention_dropout": 0.0,
31
+ "dropout": 0.0,
32
+ "hidden_act": "gelu",
33
+ "hidden_size": 1280,
34
+ "image_size": 224,
35
+ "initializer_factor": 1.0,
36
+ "initializer_range": 0.02,
37
+ "intermediate_size": 5120,
38
+ "layer_norm_eps": 1e-05,
39
+ "model_type": "clip_vision_model",
40
+ "num_attention_heads": 16,
41
+ "num_channels": 3,
42
+ "num_hidden_layers": 32,
43
+ "patch_size": 14,
44
+ "projection_dim": 1024,
45
+ "torch_dtype": "float32"
46
+ }
47
+
48
+ config_clip_vitl = {
49
+ "attention_dropout": 0.0,
50
+ "dropout": 0.0,
51
+ "hidden_act": "quick_gelu",
52
+ "hidden_size": 1024,
53
+ "image_size": 224,
54
+ "initializer_factor": 1.0,
55
+ "initializer_range": 0.02,
56
+ "intermediate_size": 4096,
57
+ "layer_norm_eps": 1e-05,
58
+ "model_type": "clip_vision_model",
59
+ "num_attention_heads": 16,
60
+ "num_channels": 3,
61
+ "num_hidden_layers": 24,
62
+ "patch_size": 14,
63
+ "projection_dim": 768,
64
+ "torch_dtype": "float32"
65
+ }
66
+
67
+ configs = {
68
+ 'clip_g': config_clip_g,
69
+ 'clip_h': config_clip_h,
70
+ 'clip_vitl': config_clip_vitl,
71
+ }
72
+
73
+ downloads = {
74
+ 'clip_vitl': 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin',
75
+ 'clip_g': 'https://huggingface.co/lllyasviel/Annotators/resolve/main/clip_g.pth',
76
+ 'clip_h': 'https://huggingface.co/h94/IP-Adapter/resolve/main/models/image_encoder/pytorch_model.bin'
77
+ }
78
+
79
+
80
+ clip_vision_h_uc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision_h_uc.data')
81
+ clip_vision_h_uc = torch.load(clip_vision_h_uc)['uc']
82
+
83
+
84
+ class ClipVisionDetector:
85
+ def __init__(self, config):
86
+ assert config in downloads
87
+ self.download_link = downloads[config]
88
+ self.model_path = os.path.join(models_path, 'clip_vision')
89
+ self.file_name = config + '.pth'
90
+ self.config = configs[config]
91
+ self.device = devices.get_device_for("controlnet")
92
+ os.makedirs(self.model_path, exist_ok=True)
93
+ file_path = os.path.join(self.model_path, self.file_name)
94
+ if not os.path.exists(file_path):
95
+ load_file_from_url(url=self.download_link, model_dir=self.model_path, file_name=self.file_name)
96
+ config = CLIPVisionConfig(**self.config)
97
+ self.model = CLIPVisionModelWithProjection(config)
98
+ self.processor = CLIPImageProcessor(crop_size=224,
99
+ do_center_crop=True,
100
+ do_convert_rgb=True,
101
+ do_normalize=True,
102
+ do_resize=True,
103
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
104
+ image_std=[0.26862954, 0.26130258, 0.27577711],
105
+ resample=3,
106
+ size=224)
107
+
108
+ sd = torch.load(file_path, map_location=torch.device('cpu'))
109
+ self.model.load_state_dict(sd, strict=False)
110
+ del sd
111
+
112
+ self.model.eval()
113
+ self.model.cpu()
114
+
115
+ def unload_model(self):
116
+ if self.model is not None:
117
+ self.model.to('meta')
118
+
119
+ def __call__(self, input_image):
120
+ with torch.no_grad():
121
+ clip_vision_model = self.model.cpu()
122
+ feat = self.processor(images=input_image, return_tensors="pt")
123
+ feat['pixel_values'] = feat['pixel_values'].cpu()
124
+ result = clip_vision_model(**feat, output_hidden_states=True)
125
+ result['hidden_states'] = [v.to(devices.get_device_for("controlnet")) for v in result['hidden_states']]
126
+ result = {k: v.to(devices.get_device_for("controlnet")) if isinstance(v, torch.Tensor) else v for k, v in result.items()}
127
+ return result
annotator/clipvision/clip_vision_h_uc.data ADDED
Binary file (659 kB). View file
 
annotator/color/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+ def cv2_resize_shortest_edge(image, size):
4
+ h, w = image.shape[:2]
5
+ if h < w:
6
+ new_h = size
7
+ new_w = int(round(w / h * size))
8
+ else:
9
+ new_w = size
10
+ new_h = int(round(h / w * size))
11
+ resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
12
+ return resized_image
13
+
14
+ def apply_color(img, res=512):
15
+ img = cv2_resize_shortest_edge(img, res)
16
+ h, w = img.shape[:2]
17
+
18
+ input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC)
19
+ input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
20
+ return input_img_color
annotator/hed/__init__.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is an improved version and model of HED edge detection with Apache License, Version 2.0.
2
+ # Please use this implementation in your products
3
+ # This implementation may produce slightly different results from Saining Xie's official implementations,
4
+ # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
5
+ # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
6
+ # and in this way it works better for gradio's RGB protocol
7
+
8
+ import os
9
+ import cv2
10
+ import torch
11
+ import numpy as np
12
+
13
+ from einops import rearrange
14
+ import os
15
+ from annotator.annotator_path import models_path, DEVICE
16
+ from annotator.util import safe_step, nms
17
+
18
+
19
+ class DoubleConvBlock(torch.nn.Module):
20
+ def __init__(self, input_channel, output_channel, layer_number):
21
+ super().__init__()
22
+ self.convs = torch.nn.Sequential()
23
+ self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
24
+ for i in range(1, layer_number):
25
+ self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
26
+ self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
27
+
28
+ def __call__(self, x, down_sampling=False):
29
+ h = x
30
+ if down_sampling:
31
+ h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
32
+ for conv in self.convs:
33
+ h = conv(h)
34
+ h = torch.nn.functional.relu(h)
35
+ return h, self.projection(h)
36
+
37
+
38
+ class ControlNetHED_Apache2(torch.nn.Module):
39
+ def __init__(self):
40
+ super().__init__()
41
+ self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
42
+ self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
43
+ self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
44
+ self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
45
+ self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
46
+ self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
47
+
48
+ def __call__(self, x):
49
+ h = x - self.norm
50
+ h, projection1 = self.block1(h)
51
+ h, projection2 = self.block2(h, down_sampling=True)
52
+ h, projection3 = self.block3(h, down_sampling=True)
53
+ h, projection4 = self.block4(h, down_sampling=True)
54
+ h, projection5 = self.block5(h, down_sampling=True)
55
+ return projection1, projection2, projection3, projection4, projection5
56
+
57
+
58
+ netNetwork = None
59
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
60
+ modeldir = os.path.join(models_path, "hed")
61
+ old_modeldir = os.path.dirname(os.path.realpath(__file__))
62
+
63
+
64
+ def apply_hed(input_image, is_safe=False):
65
+ global netNetwork
66
+ if netNetwork is None:
67
+ modelpath = os.path.join(modeldir, "ControlNetHED.pth")
68
+ old_modelpath = os.path.join(old_modeldir, "ControlNetHED.pth")
69
+ if os.path.exists(old_modelpath):
70
+ modelpath = old_modelpath
71
+ elif not os.path.exists(modelpath):
72
+ from basicsr.utils.download_util import load_file_from_url
73
+ load_file_from_url(remote_model_path, model_dir=modeldir)
74
+ netNetwork = ControlNetHED_Apache2().to(DEVICE)
75
+ netNetwork.load_state_dict(torch.load(modelpath, map_location='cpu'))
76
+ netNetwork.to(DEVICE).float().eval()
77
+
78
+ assert input_image.ndim == 3
79
+ H, W, C = input_image.shape
80
+ with torch.no_grad():
81
+ image_hed = torch.from_numpy(input_image.copy()).float().to(DEVICE)
82
+ image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
83
+ edges = netNetwork(image_hed)
84
+ edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
85
+ edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
86
+ edges = np.stack(edges, axis=2)
87
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
88
+ if is_safe:
89
+ edge = safe_step(edge)
90
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
91
+ return edge
92
+
93
+
94
+ def unload_hed_model():
95
+ global netNetwork
96
+ if netNetwork is not None:
97
+ netNetwork.cpu()
annotator/hed/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (8.52 kB). View file
 
annotator/hed/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (4.01 kB). View file
 
annotator/hed/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (4.02 kB). View file
 
annotator/keypose/__init__.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import torch
4
+
5
+ import os
6
+ from modules import devices
7
+ from annotator.annotator_path import models_path
8
+
9
+ import mmcv
10
+ from mmdet.apis import inference_detector, init_detector
11
+ from mmpose.apis import inference_top_down_pose_model
12
+ from mmpose.apis import init_pose_model, process_mmdet_results, vis_pose_result
13
+
14
+
15
+ def preprocessing(image, device):
16
+ # Resize
17
+ scale = 640 / max(image.shape[:2])
18
+ image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
19
+ raw_image = image.astype(np.uint8)
20
+
21
+ # Subtract mean values
22
+ image = image.astype(np.float32)
23
+ image -= np.array(
24
+ [
25
+ float(104.008),
26
+ float(116.669),
27
+ float(122.675),
28
+ ]
29
+ )
30
+
31
+ # Convert to torch.Tensor and add "batch" axis
32
+ image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
33
+ image = image.to(device)
34
+
35
+ return image, raw_image
36
+
37
+
38
+ def imshow_keypoints(img,
39
+ pose_result,
40
+ skeleton=None,
41
+ kpt_score_thr=0.1,
42
+ pose_kpt_color=None,
43
+ pose_link_color=None,
44
+ radius=4,
45
+ thickness=1):
46
+ """Draw keypoints and links on an image.
47
+ Args:
48
+ img (ndarry): The image to draw poses on.
49
+ pose_result (list[kpts]): The poses to draw. Each element kpts is
50
+ a set of K keypoints as an Kx3 numpy.ndarray, where each
51
+ keypoint is represented as x, y, score.
52
+ kpt_score_thr (float, optional): Minimum score of keypoints
53
+ to be shown. Default: 0.3.
54
+ pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
55
+ the keypoint will not be drawn.
56
+ pose_link_color (np.array[Mx3]): Color of M links. If None, the
57
+ links will not be drawn.
58
+ thickness (int): Thickness of lines.
59
+ """
60
+
61
+ img_h, img_w, _ = img.shape
62
+ img = np.zeros(img.shape)
63
+
64
+ for idx, kpts in enumerate(pose_result):
65
+ if idx > 1:
66
+ continue
67
+ kpts = kpts['keypoints']
68
+ # print(kpts)
69
+ kpts = np.array(kpts, copy=False)
70
+
71
+ # draw each point on image
72
+ if pose_kpt_color is not None:
73
+ assert len(pose_kpt_color) == len(kpts)
74
+
75
+ for kid, kpt in enumerate(kpts):
76
+ x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
77
+
78
+ if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
79
+ # skip the point that should not be drawn
80
+ continue
81
+
82
+ color = tuple(int(c) for c in pose_kpt_color[kid])
83
+ cv2.circle(img, (int(x_coord), int(y_coord)),
84
+ radius, color, -1)
85
+
86
+ # draw links
87
+ if skeleton is not None and pose_link_color is not None:
88
+ assert len(pose_link_color) == len(skeleton)
89
+
90
+ for sk_id, sk in enumerate(skeleton):
91
+ pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
92
+ pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
93
+
94
+ if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
95
+ or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
96
+ or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
97
+ # skip the link that should not be drawn
98
+ continue
99
+ color = tuple(int(c) for c in pose_link_color[sk_id])
100
+ cv2.line(img, pos1, pos2, color, thickness=thickness)
101
+
102
+ return img
103
+
104
+
105
+ human_det, pose_model = None, None
106
+ det_model_path = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
107
+ pose_model_path = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"
108
+
109
+ modeldir = os.path.join(models_path, "keypose")
110
+ old_modeldir = os.path.dirname(os.path.realpath(__file__))
111
+
112
+ det_config = 'faster_rcnn_r50_fpn_coco.py'
113
+ pose_config = 'hrnet_w48_coco_256x192.py'
114
+
115
+ det_checkpoint = 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
116
+ pose_checkpoint = 'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
117
+ det_cat_id = 1
118
+ bbox_thr = 0.2
119
+
120
+ skeleton = [
121
+ [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8],
122
+ [7, 9], [8, 10],
123
+ [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]
124
+ ]
125
+
126
+ pose_kpt_color = [
127
+ [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
128
+ [0, 255, 0],
129
+ [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0],
130
+ [255, 128, 0],
131
+ [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]
132
+ ]
133
+
134
+ pose_link_color = [
135
+ [0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
136
+ [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
137
+ [255, 128, 0],
138
+ [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
139
+ [51, 153, 255],
140
+ [51, 153, 255], [51, 153, 255], [51, 153, 255]
141
+ ]
142
+
143
+ def find_download_model(checkpoint, remote_path):
144
+ modelpath = os.path.join(modeldir, checkpoint)
145
+ old_modelpath = os.path.join(old_modeldir, checkpoint)
146
+
147
+ if os.path.exists(old_modelpath):
148
+ modelpath = old_modelpath
149
+ elif not os.path.exists(modelpath):
150
+ from basicsr.utils.download_util import load_file_from_url
151
+ load_file_from_url(remote_path, model_dir=modeldir)
152
+
153
+ return modelpath
154
+
155
+ def apply_keypose(input_image):
156
+ global human_det, pose_model
157
+ if netNetwork is None:
158
+ det_model_local = find_download_model(det_checkpoint, det_model_path)
159
+ hrnet_model_local = find_download_model(pose_checkpoint, pose_model_path)
160
+ det_config_mmcv = mmcv.Config.fromfile(det_config)
161
+ pose_config_mmcv = mmcv.Config.fromfile(pose_config)
162
+ human_det = init_detector(det_config_mmcv, det_model_local, device=devices.get_device_for("controlnet"))
163
+ pose_model = init_pose_model(pose_config_mmcv, hrnet_model_local, device=devices.get_device_for("controlnet"))
164
+
165
+ assert input_image.ndim == 3
166
+ input_image = input_image.copy()
167
+ with torch.no_grad():
168
+ image = torch.from_numpy(input_image).float().to(devices.get_device_for("controlnet"))
169
+ image = image / 255.0
170
+ mmdet_results = inference_detector(human_det, image)
171
+
172
+ # keep the person class bounding boxes.
173
+ person_results = process_mmdet_results(mmdet_results, det_cat_id)
174
+
175
+ return_heatmap = False
176
+ dataset = pose_model.cfg.data['test']['type']
177
+
178
+ # e.g. use ('backbone', ) to return backbone feature
179
+ output_layer_names = None
180
+ pose_results, _ = inference_top_down_pose_model(
181
+ pose_model,
182
+ image,
183
+ person_results,
184
+ bbox_thr=bbox_thr,
185
+ format='xyxy',
186
+ dataset=dataset,
187
+ dataset_info=None,
188
+ return_heatmap=return_heatmap,
189
+ outputs=output_layer_names
190
+ )
191
+
192
+ im_keypose_out = imshow_keypoints(
193
+ image,
194
+ pose_results,
195
+ skeleton=skeleton,
196
+ pose_kpt_color=pose_kpt_color,
197
+ pose_link_color=pose_link_color,
198
+ radius=2,
199
+ thickness=2
200
+ )
201
+ im_keypose_out = im_keypose_out.astype(np.uint8)
202
+
203
+ # image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
204
+ # edge = netNetwork(image_hed)[0]
205
+ # edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
206
+ return im_keypose_out
207
+
208
+
209
+ def unload_hed_model():
210
+ global netNetwork
211
+ if netNetwork is not None:
212
+ netNetwork.cpu()
annotator/keypose/faster_rcnn_r50_fpn_coco.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ checkpoint_config = dict(interval=1)
2
+ # yapf:disable
3
+ log_config = dict(
4
+ interval=50,
5
+ hooks=[
6
+ dict(type='TextLoggerHook'),
7
+ # dict(type='TensorboardLoggerHook')
8
+ ])
9
+ # yapf:enable
10
+ dist_params = dict(backend='nccl')
11
+ log_level = 'INFO'
12
+ load_from = None
13
+ resume_from = None
14
+ workflow = [('train', 1)]
15
+ # optimizer
16
+ optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
17
+ optimizer_config = dict(grad_clip=None)
18
+ # learning policy
19
+ lr_config = dict(
20
+ policy='step',
21
+ warmup='linear',
22
+ warmup_iters=500,
23
+ warmup_ratio=0.001,
24
+ step=[8, 11])
25
+ total_epochs = 12
26
+
27
+ model = dict(
28
+ type='FasterRCNN',
29
+ pretrained='torchvision://resnet50',
30
+ backbone=dict(
31
+ type='ResNet',
32
+ depth=50,
33
+ num_stages=4,
34
+ out_indices=(0, 1, 2, 3),
35
+ frozen_stages=1,
36
+ norm_cfg=dict(type='BN', requires_grad=True),
37
+ norm_eval=True,
38
+ style='pytorch'),
39
+ neck=dict(
40
+ type='FPN',
41
+ in_channels=[256, 512, 1024, 2048],
42
+ out_channels=256,
43
+ num_outs=5),
44
+ rpn_head=dict(
45
+ type='RPNHead',
46
+ in_channels=256,
47
+ feat_channels=256,
48
+ anchor_generator=dict(
49
+ type='AnchorGenerator',
50
+ scales=[8],
51
+ ratios=[0.5, 1.0, 2.0],
52
+ strides=[4, 8, 16, 32, 64]),
53
+ bbox_coder=dict(
54
+ type='DeltaXYWHBBoxCoder',
55
+ target_means=[.0, .0, .0, .0],
56
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
57
+ loss_cls=dict(
58
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
59
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
60
+ roi_head=dict(
61
+ type='StandardRoIHead',
62
+ bbox_roi_extractor=dict(
63
+ type='SingleRoIExtractor',
64
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
65
+ out_channels=256,
66
+ featmap_strides=[4, 8, 16, 32]),
67
+ bbox_head=dict(
68
+ type='Shared2FCBBoxHead',
69
+ in_channels=256,
70
+ fc_out_channels=1024,
71
+ roi_feat_size=7,
72
+ num_classes=80,
73
+ bbox_coder=dict(
74
+ type='DeltaXYWHBBoxCoder',
75
+ target_means=[0., 0., 0., 0.],
76
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
77
+ reg_class_agnostic=False,
78
+ loss_cls=dict(
79
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
80
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
81
+ # model training and testing settings
82
+ train_cfg=dict(
83
+ rpn=dict(
84
+ assigner=dict(
85
+ type='MaxIoUAssigner',
86
+ pos_iou_thr=0.7,
87
+ neg_iou_thr=0.3,
88
+ min_pos_iou=0.3,
89
+ match_low_quality=True,
90
+ ignore_iof_thr=-1),
91
+ sampler=dict(
92
+ type='RandomSampler',
93
+ num=256,
94
+ pos_fraction=0.5,
95
+ neg_pos_ub=-1,
96
+ add_gt_as_proposals=False),
97
+ allowed_border=-1,
98
+ pos_weight=-1,
99
+ debug=False),
100
+ rpn_proposal=dict(
101
+ nms_pre=2000,
102
+ max_per_img=1000,
103
+ nms=dict(type='nms', iou_threshold=0.7),
104
+ min_bbox_size=0),
105
+ rcnn=dict(
106
+ assigner=dict(
107
+ type='MaxIoUAssigner',
108
+ pos_iou_thr=0.5,
109
+ neg_iou_thr=0.5,
110
+ min_pos_iou=0.5,
111
+ match_low_quality=False,
112
+ ignore_iof_thr=-1),
113
+ sampler=dict(
114
+ type='RandomSampler',
115
+ num=512,
116
+ pos_fraction=0.25,
117
+ neg_pos_ub=-1,
118
+ add_gt_as_proposals=True),
119
+ pos_weight=-1,
120
+ debug=False)),
121
+ test_cfg=dict(
122
+ rpn=dict(
123
+ nms_pre=1000,
124
+ max_per_img=1000,
125
+ nms=dict(type='nms', iou_threshold=0.7),
126
+ min_bbox_size=0),
127
+ rcnn=dict(
128
+ score_thr=0.05,
129
+ nms=dict(type='nms', iou_threshold=0.5),
130
+ max_per_img=100)
131
+ # soft-nms is also supported for rcnn testing
132
+ # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
133
+ ))
134
+
135
+ dataset_type = 'CocoDataset'
136
+ data_root = 'data/coco'
137
+ img_norm_cfg = dict(
138
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
139
+ train_pipeline = [
140
+ dict(type='LoadImageFromFile'),
141
+ dict(type='LoadAnnotations', with_bbox=True),
142
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
143
+ dict(type='RandomFlip', flip_ratio=0.5),
144
+ dict(type='Normalize', **img_norm_cfg),
145
+ dict(type='Pad', size_divisor=32),
146
+ dict(type='DefaultFormatBundle'),
147
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
148
+ ]
149
+ test_pipeline = [
150
+ dict(type='LoadImageFromFile'),
151
+ dict(
152
+ type='MultiScaleFlipAug',
153
+ img_scale=(1333, 800),
154
+ flip=False,
155
+ transforms=[
156
+ dict(type='Resize', keep_ratio=True),
157
+ dict(type='RandomFlip'),
158
+ dict(type='Normalize', **img_norm_cfg),
159
+ dict(type='Pad', size_divisor=32),
160
+ dict(type='DefaultFormatBundle'),
161
+ dict(type='Collect', keys=['img']),
162
+ ])
163
+ ]
164
+ data = dict(
165
+ samples_per_gpu=2,
166
+ workers_per_gpu=2,
167
+ train=dict(
168
+ type=dataset_type,
169
+ ann_file=f'{data_root}/annotations/instances_train2017.json',
170
+ img_prefix=f'{data_root}/train2017/',
171
+ pipeline=train_pipeline),
172
+ val=dict(
173
+ type=dataset_type,
174
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
175
+ img_prefix=f'{data_root}/val2017/',
176
+ pipeline=test_pipeline),
177
+ test=dict(
178
+ type=dataset_type,
179
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
180
+ img_prefix=f'{data_root}/val2017/',
181
+ pipeline=test_pipeline))
182
+ evaluation = dict(interval=1, metric='bbox')
annotator/keypose/hrnet_w48_coco_256x192.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # _base_ = [
2
+ # '../../../../_base_/default_runtime.py',
3
+ # '../../../../_base_/datasets/coco.py'
4
+ # ]
5
+ evaluation = dict(interval=10, metric='mAP', save_best='AP')
6
+
7
+ optimizer = dict(
8
+ type='Adam',
9
+ lr=5e-4,
10
+ )
11
+ optimizer_config = dict(grad_clip=None)
12
+ # learning policy
13
+ lr_config = dict(
14
+ policy='step',
15
+ warmup='linear',
16
+ warmup_iters=500,
17
+ warmup_ratio=0.001,
18
+ step=[170, 200])
19
+ total_epochs = 210
20
+ channel_cfg = dict(
21
+ num_output_channels=17,
22
+ dataset_joints=17,
23
+ dataset_channel=[
24
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
25
+ ],
26
+ inference_channel=[
27
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
28
+ ])
29
+
30
+ # model settings
31
+ model = dict(
32
+ type='TopDown',
33
+ pretrained='https://download.openmmlab.com/mmpose/'
34
+ 'pretrain_models/hrnet_w48-8ef0771d.pth',
35
+ backbone=dict(
36
+ type='HRNet',
37
+ in_channels=3,
38
+ extra=dict(
39
+ stage1=dict(
40
+ num_modules=1,
41
+ num_branches=1,
42
+ block='BOTTLENECK',
43
+ num_blocks=(4, ),
44
+ num_channels=(64, )),
45
+ stage2=dict(
46
+ num_modules=1,
47
+ num_branches=2,
48
+ block='BASIC',
49
+ num_blocks=(4, 4),
50
+ num_channels=(48, 96)),
51
+ stage3=dict(
52
+ num_modules=4,
53
+ num_branches=3,
54
+ block='BASIC',
55
+ num_blocks=(4, 4, 4),
56
+ num_channels=(48, 96, 192)),
57
+ stage4=dict(
58
+ num_modules=3,
59
+ num_branches=4,
60
+ block='BASIC',
61
+ num_blocks=(4, 4, 4, 4),
62
+ num_channels=(48, 96, 192, 384))),
63
+ ),
64
+ keypoint_head=dict(
65
+ type='TopdownHeatmapSimpleHead',
66
+ in_channels=48,
67
+ out_channels=channel_cfg['num_output_channels'],
68
+ num_deconv_layers=0,
69
+ extra=dict(final_conv_kernel=1, ),
70
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
71
+ train_cfg=dict(),
72
+ test_cfg=dict(
73
+ flip_test=True,
74
+ post_process='default',
75
+ shift_heatmap=True,
76
+ modulate_kernel=11))
77
+
78
+ data_cfg = dict(
79
+ image_size=[192, 256],
80
+ heatmap_size=[48, 64],
81
+ num_output_channels=channel_cfg['num_output_channels'],
82
+ num_joints=channel_cfg['dataset_joints'],
83
+ dataset_channel=channel_cfg['dataset_channel'],
84
+ inference_channel=channel_cfg['inference_channel'],
85
+ soft_nms=False,
86
+ nms_thr=1.0,
87
+ oks_thr=0.9,
88
+ vis_thr=0.2,
89
+ use_gt_bbox=False,
90
+ det_bbox_thr=0.0,
91
+ bbox_file='data/coco/person_detection_results/'
92
+ 'COCO_val2017_detections_AP_H_56_person.json',
93
+ )
94
+
95
+ train_pipeline = [
96
+ dict(type='LoadImageFromFile'),
97
+ dict(type='TopDownGetBboxCenterScale', padding=1.25),
98
+ dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
99
+ dict(type='TopDownRandomFlip', flip_prob=0.5),
100
+ dict(
101
+ type='TopDownHalfBodyTransform',
102
+ num_joints_half_body=8,
103
+ prob_half_body=0.3),
104
+ dict(
105
+ type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
106
+ dict(type='TopDownAffine'),
107
+ dict(type='ToTensor'),
108
+ dict(
109
+ type='NormalizeTensor',
110
+ mean=[0.485, 0.456, 0.406],
111
+ std=[0.229, 0.224, 0.225]),
112
+ dict(type='TopDownGenerateTarget', sigma=2),
113
+ dict(
114
+ type='Collect',
115
+ keys=['img', 'target', 'target_weight'],
116
+ meta_keys=[
117
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
118
+ 'rotation', 'bbox_score', 'flip_pairs'
119
+ ]),
120
+ ]
121
+
122
+ val_pipeline = [
123
+ dict(type='LoadImageFromFile'),
124
+ dict(type='TopDownGetBboxCenterScale', padding=1.25),
125
+ dict(type='TopDownAffine'),
126
+ dict(type='ToTensor'),
127
+ dict(
128
+ type='NormalizeTensor',
129
+ mean=[0.485, 0.456, 0.406],
130
+ std=[0.229, 0.224, 0.225]),
131
+ dict(
132
+ type='Collect',
133
+ keys=['img'],
134
+ meta_keys=[
135
+ 'image_file', 'center', 'scale', 'rotation', 'bbox_score',
136
+ 'flip_pairs'
137
+ ]),
138
+ ]
139
+
140
+ test_pipeline = val_pipeline
141
+
142
+ data_root = 'data/coco'
143
+ data = dict(
144
+ samples_per_gpu=32,
145
+ workers_per_gpu=2,
146
+ val_dataloader=dict(samples_per_gpu=32),
147
+ test_dataloader=dict(samples_per_gpu=32),
148
+ train=dict(
149
+ type='TopDownCocoDataset',
150
+ ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
151
+ img_prefix=f'{data_root}/train2017/',
152
+ data_cfg=data_cfg,
153
+ pipeline=train_pipeline,
154
+ dataset_info={{_base_.dataset_info}}),
155
+ val=dict(
156
+ type='TopDownCocoDataset',
157
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
158
+ img_prefix=f'{data_root}/val2017/',
159
+ data_cfg=data_cfg,
160
+ pipeline=val_pipeline,
161
+ dataset_info={{_base_.dataset_info}}),
162
+ test=dict(
163
+ type='TopDownCocoDataset',
164
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
165
+ img_prefix=f'{data_root}/val2017/',
166
+ data_cfg=data_cfg,
167
+ pipeline=test_pipeline,
168
+ dataset_info={{_base_.dataset_info}}),
169
+ )
annotator/lama/__init__.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/advimman/lama
2
+
3
+ import yaml
4
+ import torch
5
+ from omegaconf import OmegaConf
6
+ import numpy as np
7
+
8
+ from einops import rearrange
9
+ import os
10
+ from modules import devices
11
+ from annotator.annotator_path import models_path
12
+ from annotator.lama.saicinpainting.training.trainers import load_checkpoint
13
+
14
+
15
+ class LamaInpainting:
16
+ model_dir = os.path.join(models_path, "lama")
17
+
18
+ def __init__(self):
19
+ self.model = None
20
+ self.device = devices.get_device_for("controlnet")
21
+
22
+ def load_model(self):
23
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetLama.pth"
24
+ modelpath = os.path.join(self.model_dir, "ControlNetLama.pth")
25
+ if not os.path.exists(modelpath):
26
+ from basicsr.utils.download_util import load_file_from_url
27
+ load_file_from_url(remote_model_path, model_dir=self.model_dir)
28
+ config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yaml')
29
+ cfg = yaml.safe_load(open(config_path, 'rt'))
30
+ cfg = OmegaConf.create(cfg)
31
+ cfg.training_model.predict_only = True
32
+ cfg.visualizer.kind = 'noop'
33
+ self.model = load_checkpoint(cfg, os.path.abspath(modelpath), strict=False, map_location='cpu')
34
+ self.model = self.model.to(self.device)
35
+ self.model.eval()
36
+
37
+ def unload_model(self):
38
+ if self.model is not None:
39
+ self.model.cpu()
40
+
41
+ def __call__(self, input_image):
42
+ if self.model is None:
43
+ self.load_model()
44
+ self.model.to(self.device)
45
+ color = np.ascontiguousarray(input_image[:, :, 0:3]).astype(np.float32) / 255.0
46
+ mask = np.ascontiguousarray(input_image[:, :, 3:4]).astype(np.float32) / 255.0
47
+ with torch.no_grad():
48
+ color = torch.from_numpy(color).float().to(self.device)
49
+ mask = torch.from_numpy(mask).float().to(self.device)
50
+ mask = (mask > 0.5).float()
51
+ color = color * (1 - mask)
52
+ image_feed = torch.cat([color, mask], dim=2)
53
+ image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
54
+ result = self.model(image_feed)[0]
55
+ result = rearrange(result, 'c h w -> h w c')
56
+ result = result * mask + color * (1 - mask)
57
+ result *= 255.0
58
+ return result.detach().cpu().numpy().clip(0, 255).astype(np.uint8)
annotator/lama/config.yaml ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run_title: b18_ffc075_batch8x15
2
+ training_model:
3
+ kind: default
4
+ visualize_each_iters: 1000
5
+ concat_mask: true
6
+ store_discr_outputs_for_vis: true
7
+ losses:
8
+ l1:
9
+ weight_missing: 0
10
+ weight_known: 10
11
+ perceptual:
12
+ weight: 0
13
+ adversarial:
14
+ kind: r1
15
+ weight: 10
16
+ gp_coef: 0.001
17
+ mask_as_fake_target: true
18
+ allow_scale_mask: true
19
+ feature_matching:
20
+ weight: 100
21
+ resnet_pl:
22
+ weight: 30
23
+ weights_path: ${env:TORCH_HOME}
24
+
25
+ optimizers:
26
+ generator:
27
+ kind: adam
28
+ lr: 0.001
29
+ discriminator:
30
+ kind: adam
31
+ lr: 0.0001
32
+ visualizer:
33
+ key_order:
34
+ - image
35
+ - predicted_image
36
+ - discr_output_fake
37
+ - discr_output_real
38
+ - inpainted
39
+ rescale_keys:
40
+ - discr_output_fake
41
+ - discr_output_real
42
+ kind: directory
43
+ outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
44
+ location:
45
+ data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
46
+ out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
47
+ tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
48
+ data:
49
+ batch_size: 15
50
+ val_batch_size: 2
51
+ num_workers: 3
52
+ train:
53
+ indir: ${location.data_root_dir}/train
54
+ out_size: 256
55
+ mask_gen_kwargs:
56
+ irregular_proba: 1
57
+ irregular_kwargs:
58
+ max_angle: 4
59
+ max_len: 200
60
+ max_width: 100
61
+ max_times: 5
62
+ min_times: 1
63
+ box_proba: 1
64
+ box_kwargs:
65
+ margin: 10
66
+ bbox_min_size: 30
67
+ bbox_max_size: 150
68
+ max_times: 3
69
+ min_times: 1
70
+ segm_proba: 0
71
+ segm_kwargs:
72
+ confidence_threshold: 0.5
73
+ max_object_area: 0.5
74
+ min_mask_area: 0.07
75
+ downsample_levels: 6
76
+ num_variants_per_mask: 1
77
+ rigidness_mode: 1
78
+ max_foreground_coverage: 0.3
79
+ max_foreground_intersection: 0.7
80
+ max_mask_intersection: 0.1
81
+ max_hidden_area: 0.1
82
+ max_scale_change: 0.25
83
+ horizontal_flip: true
84
+ max_vertical_shift: 0.2
85
+ position_shuffle: true
86
+ transform_variant: distortions
87
+ dataloader_kwargs:
88
+ batch_size: ${data.batch_size}
89
+ shuffle: true
90
+ num_workers: ${data.num_workers}
91
+ val:
92
+ indir: ${location.data_root_dir}/val
93
+ img_suffix: .png
94
+ dataloader_kwargs:
95
+ batch_size: ${data.val_batch_size}
96
+ shuffle: false
97
+ num_workers: ${data.num_workers}
98
+ visual_test:
99
+ indir: ${location.data_root_dir}/korean_test
100
+ img_suffix: _input.png
101
+ pad_out_to_modulo: 32
102
+ dataloader_kwargs:
103
+ batch_size: 1
104
+ shuffle: false
105
+ num_workers: ${data.num_workers}
106
+ generator:
107
+ kind: ffc_resnet
108
+ input_nc: 4
109
+ output_nc: 3
110
+ ngf: 64
111
+ n_downsampling: 3
112
+ n_blocks: 18
113
+ add_out_act: sigmoid
114
+ init_conv_kwargs:
115
+ ratio_gin: 0
116
+ ratio_gout: 0
117
+ enable_lfu: false
118
+ downsample_conv_kwargs:
119
+ ratio_gin: ${generator.init_conv_kwargs.ratio_gout}
120
+ ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}
121
+ enable_lfu: false
122
+ resnet_conv_kwargs:
123
+ ratio_gin: 0.75
124
+ ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin}
125
+ enable_lfu: false
126
+ discriminator:
127
+ kind: pix2pixhd_nlayer
128
+ input_nc: 3
129
+ ndf: 64
130
+ n_layers: 4
131
+ evaluator:
132
+ kind: default
133
+ inpainted_key: inpainted
134
+ integral_kind: ssim_fid100_f1
135
+ trainer:
136
+ kwargs:
137
+ gpus: -1
138
+ accelerator: ddp
139
+ max_epochs: 200
140
+ gradient_clip_val: 1
141
+ log_gpu_memory: None
142
+ limit_train_batches: 25000
143
+ val_check_interval: ${trainer.kwargs.limit_train_batches}
144
+ log_every_n_steps: 1000
145
+ precision: 32
146
+ terminate_on_nan: false
147
+ check_val_every_n_epoch: 1
148
+ num_sanity_val_steps: 8
149
+ limit_val_batches: 1000
150
+ replace_sampler_ddp: false
151
+ checkpoint_kwargs:
152
+ verbose: true
153
+ save_top_k: 5
154
+ save_last: true
155
+ period: 1
156
+ monitor: val_ssim_fid100_f1_total_mean
157
+ mode: max
annotator/lama/saicinpainting/__init__.py ADDED
File without changes
annotator/lama/saicinpainting/training/__init__.py ADDED
File without changes
annotator/lama/saicinpainting/training/data/__init__.py ADDED
File without changes
annotator/lama/saicinpainting/training/data/masks.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import hashlib
4
+ import logging
5
+ from enum import Enum
6
+
7
+ import cv2
8
+ import numpy as np
9
+
10
+ # from annotator.lama.saicinpainting.evaluation.masks.mask import SegmentationMask
11
+ from annotator.lama.saicinpainting.utils import LinearRamp
12
+
13
+ LOGGER = logging.getLogger(__name__)
14
+
15
+
16
+ class DrawMethod(Enum):
17
+ LINE = 'line'
18
+ CIRCLE = 'circle'
19
+ SQUARE = 'square'
20
+
21
+
22
+ def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
23
+ draw_method=DrawMethod.LINE):
24
+ draw_method = DrawMethod(draw_method)
25
+
26
+ height, width = shape
27
+ mask = np.zeros((height, width), np.float32)
28
+ times = np.random.randint(min_times, max_times + 1)
29
+ for i in range(times):
30
+ start_x = np.random.randint(width)
31
+ start_y = np.random.randint(height)
32
+ for j in range(1 + np.random.randint(5)):
33
+ angle = 0.01 + np.random.randint(max_angle)
34
+ if i % 2 == 0:
35
+ angle = 2 * 3.1415926 - angle
36
+ length = 10 + np.random.randint(max_len)
37
+ brush_w = 5 + np.random.randint(max_width)
38
+ end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
39
+ end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
40
+ if draw_method == DrawMethod.LINE:
41
+ cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
42
+ elif draw_method == DrawMethod.CIRCLE:
43
+ cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
44
+ elif draw_method == DrawMethod.SQUARE:
45
+ radius = brush_w // 2
46
+ mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1
47
+ start_x, start_y = end_x, end_y
48
+ return mask[None, ...]
49
+
50
+
51
+ class RandomIrregularMaskGenerator:
52
+ def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None,
53
+ draw_method=DrawMethod.LINE):
54
+ self.max_angle = max_angle
55
+ self.max_len = max_len
56
+ self.max_width = max_width
57
+ self.min_times = min_times
58
+ self.max_times = max_times
59
+ self.draw_method = draw_method
60
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
61
+
62
+ def __call__(self, img, iter_i=None, raw_image=None):
63
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
64
+ cur_max_len = int(max(1, self.max_len * coef))
65
+ cur_max_width = int(max(1, self.max_width * coef))
66
+ cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)
67
+ return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len,
68
+ max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
69
+ draw_method=self.draw_method)
70
+
71
+
72
+ def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):
73
+ height, width = shape
74
+ mask = np.zeros((height, width), np.float32)
75
+ bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
76
+ times = np.random.randint(min_times, max_times + 1)
77
+ for i in range(times):
78
+ box_width = np.random.randint(bbox_min_size, bbox_max_size)
79
+ box_height = np.random.randint(bbox_min_size, bbox_max_size)
80
+ start_x = np.random.randint(margin, width - margin - box_width + 1)
81
+ start_y = np.random.randint(margin, height - margin - box_height + 1)
82
+ mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
83
+ return mask[None, ...]
84
+
85
+
86
+ class RandomRectangleMaskGenerator:
87
+ def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None):
88
+ self.margin = margin
89
+ self.bbox_min_size = bbox_min_size
90
+ self.bbox_max_size = bbox_max_size
91
+ self.min_times = min_times
92
+ self.max_times = max_times
93
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
94
+
95
+ def __call__(self, img, iter_i=None, raw_image=None):
96
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
97
+ cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)
98
+ cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)
99
+ return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size,
100
+ bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
101
+ max_times=cur_max_times)
102
+
103
+
104
+ class RandomSegmentationMaskGenerator:
105
+ def __init__(self, **kwargs):
106
+ self.impl = None # will be instantiated in first call (effectively in subprocess)
107
+ self.kwargs = kwargs
108
+
109
+ def __call__(self, img, iter_i=None, raw_image=None):
110
+ if self.impl is None:
111
+ self.impl = SegmentationMask(**self.kwargs)
112
+
113
+ masks = self.impl.get_masks(np.transpose(img, (1, 2, 0)))
114
+ masks = [m for m in masks if len(np.unique(m)) > 1]
115
+ return np.random.choice(masks)
116
+
117
+
118
+ def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3):
119
+ height, width = shape
120
+ mask = np.zeros((height, width), np.float32)
121
+ step_x = np.random.randint(min_step, max_step + 1)
122
+ width_x = np.random.randint(min_width, min(step_x, max_width + 1))
123
+ offset_x = np.random.randint(0, step_x)
124
+
125
+ step_y = np.random.randint(min_step, max_step + 1)
126
+ width_y = np.random.randint(min_width, min(step_y, max_width + 1))
127
+ offset_y = np.random.randint(0, step_y)
128
+
129
+ for dy in range(width_y):
130
+ mask[offset_y + dy::step_y] = 1
131
+ for dx in range(width_x):
132
+ mask[:, offset_x + dx::step_x] = 1
133
+ return mask[None, ...]
134
+
135
+
136
+ class RandomSuperresMaskGenerator:
137
+ def __init__(self, **kwargs):
138
+ self.kwargs = kwargs
139
+
140
+ def __call__(self, img, iter_i=None):
141
+ return make_random_superres_mask(img.shape[1:], **self.kwargs)
142
+
143
+
144
+ class DumbAreaMaskGenerator:
145
+ min_ratio = 0.1
146
+ max_ratio = 0.35
147
+ default_ratio = 0.225
148
+
149
+ def __init__(self, is_training):
150
+ #Parameters:
151
+ # is_training(bool): If true - random rectangular mask, if false - central square mask
152
+ self.is_training = is_training
153
+
154
+ def _random_vector(self, dimension):
155
+ if self.is_training:
156
+ lower_limit = math.sqrt(self.min_ratio)
157
+ upper_limit = math.sqrt(self.max_ratio)
158
+ mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension)
159
+ u = random.randint(0, dimension-mask_side-1)
160
+ v = u+mask_side
161
+ else:
162
+ margin = (math.sqrt(self.default_ratio) / 2) * dimension
163
+ u = round(dimension/2 - margin)
164
+ v = round(dimension/2 + margin)
165
+ return u, v
166
+
167
+ def __call__(self, img, iter_i=None, raw_image=None):
168
+ c, height, width = img.shape
169
+ mask = np.zeros((height, width), np.float32)
170
+ x1, x2 = self._random_vector(width)
171
+ y1, y2 = self._random_vector(height)
172
+ mask[x1:x2, y1:y2] = 1
173
+ return mask[None, ...]
174
+
175
+
176
+ class OutpaintingMaskGenerator:
177
+ def __init__(self, min_padding_percent:float=0.04, max_padding_percent:int=0.25, left_padding_prob:float=0.5, top_padding_prob:float=0.5,
178
+ right_padding_prob:float=0.5, bottom_padding_prob:float=0.5, is_fixed_randomness:bool=False):
179
+ """
180
+ is_fixed_randomness - get identical paddings for the same image if args are the same
181
+ """
182
+ self.min_padding_percent = min_padding_percent
183
+ self.max_padding_percent = max_padding_percent
184
+ self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob]
185
+ self.is_fixed_randomness = is_fixed_randomness
186
+
187
+ assert self.min_padding_percent <= self.max_padding_percent
188
+ assert self.max_padding_percent > 0
189
+ assert len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x>=0 and x<=1)]) == 2, f"Padding percentage should be in [0,1]"
190
+ assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}"
191
+ assert len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4, f"At least one of padding probs is not in [0,1] - {self.probs}"
192
+ if len([x for x in self.probs if x > 0]) == 1:
193
+ LOGGER.warning(f"Only one padding prob is greater than zero - {self.probs}. That means that the outpainting masks will be always on the same side")
194
+
195
+ def apply_padding(self, mask, coord):
196
+ mask[int(coord[0][0]*self.img_h):int(coord[1][0]*self.img_h),
197
+ int(coord[0][1]*self.img_w):int(coord[1][1]*self.img_w)] = 1
198
+ return mask
199
+
200
+ def get_padding(self, size):
201
+ n1 = int(self.min_padding_percent*size)
202
+ n2 = int(self.max_padding_percent*size)
203
+ return self.rnd.randint(n1, n2) / size
204
+
205
+ @staticmethod
206
+ def _img2rs(img):
207
+ arr = np.ascontiguousarray(img.astype(np.uint8))
208
+ str_hash = hashlib.sha1(arr).hexdigest()
209
+ res = hash(str_hash)%(2**32)
210
+ return res
211
+
212
+ def __call__(self, img, iter_i=None, raw_image=None):
213
+ c, self.img_h, self.img_w = img.shape
214
+ mask = np.zeros((self.img_h, self.img_w), np.float32)
215
+ at_least_one_mask_applied = False
216
+
217
+ if self.is_fixed_randomness:
218
+ assert raw_image is not None, f"Cant calculate hash on raw_image=None"
219
+ rs = self._img2rs(raw_image)
220
+ self.rnd = np.random.RandomState(rs)
221
+ else:
222
+ self.rnd = np.random
223
+
224
+ coords = [[
225
+ (0,0),
226
+ (1,self.get_padding(size=self.img_h))
227
+ ],
228
+ [
229
+ (0,0),
230
+ (self.get_padding(size=self.img_w),1)
231
+ ],
232
+ [
233
+ (0,1-self.get_padding(size=self.img_h)),
234
+ (1,1)
235
+ ],
236
+ [
237
+ (1-self.get_padding(size=self.img_w),0),
238
+ (1,1)
239
+ ]]
240
+
241
+ for pp, coord in zip(self.probs, coords):
242
+ if self.rnd.random() < pp:
243
+ at_least_one_mask_applied = True
244
+ mask = self.apply_padding(mask=mask, coord=coord)
245
+
246
+ if not at_least_one_mask_applied:
247
+ idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs)/sum(self.probs))
248
+ mask = self.apply_padding(mask=mask, coord=coords[idx])
249
+ return mask[None, ...]
250
+
251
+
252
+ class MixedMaskGenerator:
253
+ def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
254
+ box_proba=1/3, box_kwargs=None,
255
+ segm_proba=1/3, segm_kwargs=None,
256
+ squares_proba=0, squares_kwargs=None,
257
+ superres_proba=0, superres_kwargs=None,
258
+ outpainting_proba=0, outpainting_kwargs=None,
259
+ invert_proba=0):
260
+ self.probas = []
261
+ self.gens = []
262
+
263
+ if irregular_proba > 0:
264
+ self.probas.append(irregular_proba)
265
+ if irregular_kwargs is None:
266
+ irregular_kwargs = {}
267
+ else:
268
+ irregular_kwargs = dict(irregular_kwargs)
269
+ irregular_kwargs['draw_method'] = DrawMethod.LINE
270
+ self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
271
+
272
+ if box_proba > 0:
273
+ self.probas.append(box_proba)
274
+ if box_kwargs is None:
275
+ box_kwargs = {}
276
+ self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
277
+
278
+ if segm_proba > 0:
279
+ self.probas.append(segm_proba)
280
+ if segm_kwargs is None:
281
+ segm_kwargs = {}
282
+ self.gens.append(RandomSegmentationMaskGenerator(**segm_kwargs))
283
+
284
+ if squares_proba > 0:
285
+ self.probas.append(squares_proba)
286
+ if squares_kwargs is None:
287
+ squares_kwargs = {}
288
+ else:
289
+ squares_kwargs = dict(squares_kwargs)
290
+ squares_kwargs['draw_method'] = DrawMethod.SQUARE
291
+ self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs))
292
+
293
+ if superres_proba > 0:
294
+ self.probas.append(superres_proba)
295
+ if superres_kwargs is None:
296
+ superres_kwargs = {}
297
+ self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs))
298
+
299
+ if outpainting_proba > 0:
300
+ self.probas.append(outpainting_proba)
301
+ if outpainting_kwargs is None:
302
+ outpainting_kwargs = {}
303
+ self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs))
304
+
305
+ self.probas = np.array(self.probas, dtype='float32')
306
+ self.probas /= self.probas.sum()
307
+ self.invert_proba = invert_proba
308
+
309
+ def __call__(self, img, iter_i=None, raw_image=None):
310
+ kind = np.random.choice(len(self.probas), p=self.probas)
311
+ gen = self.gens[kind]
312
+ result = gen(img, iter_i=iter_i, raw_image=raw_image)
313
+ if self.invert_proba > 0 and random.random() < self.invert_proba:
314
+ result = 1 - result
315
+ return result
316
+
317
+
318
+ def get_mask_generator(kind, kwargs):
319
+ if kind is None:
320
+ kind = "mixed"
321
+ if kwargs is None:
322
+ kwargs = {}
323
+
324
+ if kind == "mixed":
325
+ cl = MixedMaskGenerator
326
+ elif kind == "outpainting":
327
+ cl = OutpaintingMaskGenerator
328
+ elif kind == "dumb":
329
+ cl = DumbAreaMaskGenerator
330
+ else:
331
+ raise NotImplementedError(f"No such generator kind = {kind}")
332
+ return cl(**kwargs)
annotator/lama/saicinpainting/training/losses/__init__.py ADDED
File without changes
annotator/lama/saicinpainting/training/losses/adversarial.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Dict, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class BaseAdversarialLoss:
9
+ def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
10
+ generator: nn.Module, discriminator: nn.Module):
11
+ """
12
+ Prepare for generator step
13
+ :param real_batch: Tensor, a batch of real samples
14
+ :param fake_batch: Tensor, a batch of samples produced by generator
15
+ :param generator:
16
+ :param discriminator:
17
+ :return: None
18
+ """
19
+
20
+ def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
21
+ generator: nn.Module, discriminator: nn.Module):
22
+ """
23
+ Prepare for discriminator step
24
+ :param real_batch: Tensor, a batch of real samples
25
+ :param fake_batch: Tensor, a batch of samples produced by generator
26
+ :param generator:
27
+ :param discriminator:
28
+ :return: None
29
+ """
30
+
31
+ def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
32
+ discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
33
+ mask: Optional[torch.Tensor] = None) \
34
+ -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
35
+ """
36
+ Calculate generator loss
37
+ :param real_batch: Tensor, a batch of real samples
38
+ :param fake_batch: Tensor, a batch of samples produced by generator
39
+ :param discr_real_pred: Tensor, discriminator output for real_batch
40
+ :param discr_fake_pred: Tensor, discriminator output for fake_batch
41
+ :param mask: Tensor, actual mask, which was at input of generator when making fake_batch
42
+ :return: total generator loss along with some values that might be interesting to log
43
+ """
44
+ raise NotImplemented()
45
+
46
+ def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
47
+ discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
48
+ mask: Optional[torch.Tensor] = None) \
49
+ -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
50
+ """
51
+ Calculate discriminator loss and call .backward() on it
52
+ :param real_batch: Tensor, a batch of real samples
53
+ :param fake_batch: Tensor, a batch of samples produced by generator
54
+ :param discr_real_pred: Tensor, discriminator output for real_batch
55
+ :param discr_fake_pred: Tensor, discriminator output for fake_batch
56
+ :param mask: Tensor, actual mask, which was at input of generator when making fake_batch
57
+ :return: total discriminator loss along with some values that might be interesting to log
58
+ """
59
+ raise NotImplemented()
60
+
61
+ def interpolate_mask(self, mask, shape):
62
+ assert mask is not None
63
+ assert self.allow_scale_mask or shape == mask.shape[-2:]
64
+ if shape != mask.shape[-2:] and self.allow_scale_mask:
65
+ if self.mask_scale_mode == 'maxpool':
66
+ mask = F.adaptive_max_pool2d(mask, shape)
67
+ else:
68
+ mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode)
69
+ return mask
70
+
71
+ def make_r1_gp(discr_real_pred, real_batch):
72
+ if torch.is_grad_enabled():
73
+ grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0]
74
+ grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean()
75
+ else:
76
+ grad_penalty = 0
77
+ real_batch.requires_grad = False
78
+
79
+ return grad_penalty
80
+
81
+ class NonSaturatingWithR1(BaseAdversarialLoss):
82
+ def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False,
83
+ mask_scale_mode='nearest', extra_mask_weight_for_gen=0,
84
+ use_unmasked_for_gen=True, use_unmasked_for_discr=True):
85
+ self.gp_coef = gp_coef
86
+ self.weight = weight
87
+ # use for discr => use for gen;
88
+ # otherwise we teach only the discr to pay attention to very small difference
89
+ assert use_unmasked_for_gen or (not use_unmasked_for_discr)
90
+ # mask as target => use unmasked for discr:
91
+ # if we don't care about unmasked regions at all
92
+ # then it doesn't matter if the value of mask_as_fake_target is true or false
93
+ assert use_unmasked_for_discr or (not mask_as_fake_target)
94
+ self.use_unmasked_for_gen = use_unmasked_for_gen
95
+ self.use_unmasked_for_discr = use_unmasked_for_discr
96
+ self.mask_as_fake_target = mask_as_fake_target
97
+ self.allow_scale_mask = allow_scale_mask
98
+ self.mask_scale_mode = mask_scale_mode
99
+ self.extra_mask_weight_for_gen = extra_mask_weight_for_gen
100
+
101
+ def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
102
+ discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
103
+ mask=None) \
104
+ -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
105
+ fake_loss = F.softplus(-discr_fake_pred)
106
+ if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
107
+ not self.use_unmasked_for_gen: # == if masked region should be treated differently
108
+ mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
109
+ if not self.use_unmasked_for_gen:
110
+ fake_loss = fake_loss * mask
111
+ else:
112
+ pixel_weights = 1 + mask * self.extra_mask_weight_for_gen
113
+ fake_loss = fake_loss * pixel_weights
114
+
115
+ return fake_loss.mean() * self.weight, dict()
116
+
117
+ def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
118
+ generator: nn.Module, discriminator: nn.Module):
119
+ real_batch.requires_grad = True
120
+
121
+ def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
122
+ discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
123
+ mask=None) \
124
+ -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
125
+
126
+ real_loss = F.softplus(-discr_real_pred)
127
+ grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
128
+ fake_loss = F.softplus(discr_fake_pred)
129
+
130
+ if not self.use_unmasked_for_discr or self.mask_as_fake_target:
131
+ # == if masked region should be treated differently
132
+ mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
133
+ # use_unmasked_for_discr=False only makes sense for fakes;
134
+ # for reals there is no difference beetween two regions
135
+ fake_loss = fake_loss * mask
136
+ if self.mask_as_fake_target:
137
+ fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred)
138
+
139
+ sum_discr_loss = real_loss + grad_penalty + fake_loss
140
+ metrics = dict(discr_real_out=discr_real_pred.mean(),
141
+ discr_fake_out=discr_fake_pred.mean(),
142
+ discr_real_gp=grad_penalty)
143
+ return sum_discr_loss.mean(), metrics
144
+
145
+ class BCELoss(BaseAdversarialLoss):
146
+ def __init__(self, weight):
147
+ self.weight = weight
148
+ self.bce_loss = nn.BCEWithLogitsLoss()
149
+
150
+ def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
151
+ real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device)
152
+ fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight
153
+ return fake_loss, dict()
154
+
155
+ def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
156
+ generator: nn.Module, discriminator: nn.Module):
157
+ real_batch.requires_grad = True
158
+
159
+ def discriminator_loss(self,
160
+ mask: torch.Tensor,
161
+ discr_real_pred: torch.Tensor,
162
+ discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
163
+
164
+ real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device)
165
+ sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2
166
+ metrics = dict(discr_real_out=discr_real_pred.mean(),
167
+ discr_fake_out=discr_fake_pred.mean(),
168
+ discr_real_gp=0)
169
+ return sum_discr_loss, metrics
170
+
171
+
172
+ def make_discrim_loss(kind, **kwargs):
173
+ if kind == 'r1':
174
+ return NonSaturatingWithR1(**kwargs)
175
+ elif kind == 'bce':
176
+ return BCELoss(**kwargs)
177
+ raise ValueError(f'Unknown adversarial loss kind {kind}')
annotator/lama/saicinpainting/training/losses/constants.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ weights = {"ade20k":
2
+ [6.34517766497462,
3
+ 9.328358208955224,
4
+ 11.389521640091116,
5
+ 16.10305958132045,
6
+ 20.833333333333332,
7
+ 22.22222222222222,
8
+ 25.125628140703515,
9
+ 43.29004329004329,
10
+ 50.5050505050505,
11
+ 54.6448087431694,
12
+ 55.24861878453038,
13
+ 60.24096385542168,
14
+ 62.5,
15
+ 66.2251655629139,
16
+ 84.74576271186442,
17
+ 90.90909090909092,
18
+ 91.74311926605505,
19
+ 96.15384615384616,
20
+ 96.15384615384616,
21
+ 97.08737864077669,
22
+ 102.04081632653062,
23
+ 135.13513513513513,
24
+ 149.2537313432836,
25
+ 153.84615384615384,
26
+ 163.93442622950818,
27
+ 166.66666666666666,
28
+ 188.67924528301887,
29
+ 192.30769230769232,
30
+ 217.3913043478261,
31
+ 227.27272727272725,
32
+ 227.27272727272725,
33
+ 227.27272727272725,
34
+ 303.03030303030306,
35
+ 322.5806451612903,
36
+ 333.3333333333333,
37
+ 370.3703703703703,
38
+ 384.61538461538464,
39
+ 416.6666666666667,
40
+ 416.6666666666667,
41
+ 434.7826086956522,
42
+ 434.7826086956522,
43
+ 454.5454545454545,
44
+ 454.5454545454545,
45
+ 500.0,
46
+ 526.3157894736842,
47
+ 526.3157894736842,
48
+ 555.5555555555555,
49
+ 555.5555555555555,
50
+ 555.5555555555555,
51
+ 555.5555555555555,
52
+ 555.5555555555555,
53
+ 555.5555555555555,
54
+ 555.5555555555555,
55
+ 588.2352941176471,
56
+ 588.2352941176471,
57
+ 588.2352941176471,
58
+ 588.2352941176471,
59
+ 588.2352941176471,
60
+ 666.6666666666666,
61
+ 666.6666666666666,
62
+ 666.6666666666666,
63
+ 666.6666666666666,
64
+ 714.2857142857143,
65
+ 714.2857142857143,
66
+ 714.2857142857143,
67
+ 714.2857142857143,
68
+ 714.2857142857143,
69
+ 769.2307692307693,
70
+ 769.2307692307693,
71
+ 769.2307692307693,
72
+ 833.3333333333334,
73
+ 833.3333333333334,
74
+ 833.3333333333334,
75
+ 833.3333333333334,
76
+ 909.090909090909,
77
+ 1000.0,
78
+ 1111.111111111111,
79
+ 1111.111111111111,
80
+ 1111.111111111111,
81
+ 1111.111111111111,
82
+ 1111.111111111111,
83
+ 1250.0,
84
+ 1250.0,
85
+ 1250.0,
86
+ 1250.0,
87
+ 1250.0,
88
+ 1428.5714285714287,
89
+ 1428.5714285714287,
90
+ 1428.5714285714287,
91
+ 1428.5714285714287,
92
+ 1428.5714285714287,
93
+ 1428.5714285714287,
94
+ 1428.5714285714287,
95
+ 1666.6666666666667,
96
+ 1666.6666666666667,
97
+ 1666.6666666666667,
98
+ 1666.6666666666667,
99
+ 1666.6666666666667,
100
+ 1666.6666666666667,
101
+ 1666.6666666666667,
102
+ 1666.6666666666667,
103
+ 1666.6666666666667,
104
+ 1666.6666666666667,
105
+ 1666.6666666666667,
106
+ 2000.0,
107
+ 2000.0,
108
+ 2000.0,
109
+ 2000.0,
110
+ 2000.0,
111
+ 2000.0,
112
+ 2000.0,
113
+ 2000.0,
114
+ 2000.0,
115
+ 2000.0,
116
+ 2000.0,
117
+ 2000.0,
118
+ 2000.0,
119
+ 2000.0,
120
+ 2000.0,
121
+ 2000.0,
122
+ 2000.0,
123
+ 2500.0,
124
+ 2500.0,
125
+ 2500.0,
126
+ 2500.0,
127
+ 2500.0,
128
+ 2500.0,
129
+ 2500.0,
130
+ 2500.0,
131
+ 2500.0,
132
+ 2500.0,
133
+ 2500.0,
134
+ 2500.0,
135
+ 2500.0,
136
+ 3333.3333333333335,
137
+ 3333.3333333333335,
138
+ 3333.3333333333335,
139
+ 3333.3333333333335,
140
+ 3333.3333333333335,
141
+ 3333.3333333333335,
142
+ 3333.3333333333335,
143
+ 3333.3333333333335,
144
+ 3333.3333333333335,
145
+ 3333.3333333333335,
146
+ 3333.3333333333335,
147
+ 3333.3333333333335,
148
+ 3333.3333333333335,
149
+ 5000.0,
150
+ 5000.0,
151
+ 5000.0]
152
+ }
annotator/lama/saicinpainting/training/losses/distance_weighting.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ from annotator.lama.saicinpainting.training.losses.perceptual import IMAGENET_STD, IMAGENET_MEAN
7
+
8
+
9
+ def dummy_distance_weighter(real_img, pred_img, mask):
10
+ return mask
11
+
12
+
13
+ def get_gauss_kernel(kernel_size, width_factor=1):
14
+ coords = torch.stack(torch.meshgrid(torch.arange(kernel_size),
15
+ torch.arange(kernel_size)),
16
+ dim=0).float()
17
+ diff = torch.exp(-((coords - kernel_size // 2) ** 2).sum(0) / kernel_size / width_factor)
18
+ diff /= diff.sum()
19
+ return diff
20
+
21
+
22
+ class BlurMask(nn.Module):
23
+ def __init__(self, kernel_size=5, width_factor=1):
24
+ super().__init__()
25
+ self.filter = nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, padding_mode='replicate', bias=False)
26
+ self.filter.weight.data.copy_(get_gauss_kernel(kernel_size, width_factor=width_factor))
27
+
28
+ def forward(self, real_img, pred_img, mask):
29
+ with torch.no_grad():
30
+ result = self.filter(mask) * mask
31
+ return result
32
+
33
+
34
+ class EmulatedEDTMask(nn.Module):
35
+ def __init__(self, dilate_kernel_size=5, blur_kernel_size=5, width_factor=1):
36
+ super().__init__()
37
+ self.dilate_filter = nn.Conv2d(1, 1, dilate_kernel_size, padding=dilate_kernel_size// 2, padding_mode='replicate',
38
+ bias=False)
39
+ self.dilate_filter.weight.data.copy_(torch.ones(1, 1, dilate_kernel_size, dilate_kernel_size, dtype=torch.float))
40
+ self.blur_filter = nn.Conv2d(1, 1, blur_kernel_size, padding=blur_kernel_size // 2, padding_mode='replicate', bias=False)
41
+ self.blur_filter.weight.data.copy_(get_gauss_kernel(blur_kernel_size, width_factor=width_factor))
42
+
43
+ def forward(self, real_img, pred_img, mask):
44
+ with torch.no_grad():
45
+ known_mask = 1 - mask
46
+ dilated_known_mask = (self.dilate_filter(known_mask) > 1).float()
47
+ result = self.blur_filter(1 - dilated_known_mask) * mask
48
+ return result
49
+
50
+
51
+ class PropagatePerceptualSim(nn.Module):
52
+ def __init__(self, level=2, max_iters=10, temperature=500, erode_mask_size=3):
53
+ super().__init__()
54
+ vgg = torchvision.models.vgg19(pretrained=True).features
55
+ vgg_avg_pooling = []
56
+
57
+ for weights in vgg.parameters():
58
+ weights.requires_grad = False
59
+
60
+ cur_level_i = 0
61
+ for module in vgg.modules():
62
+ if module.__class__.__name__ == 'Sequential':
63
+ continue
64
+ elif module.__class__.__name__ == 'MaxPool2d':
65
+ vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
66
+ else:
67
+ vgg_avg_pooling.append(module)
68
+ if module.__class__.__name__ == 'ReLU':
69
+ cur_level_i += 1
70
+ if cur_level_i == level:
71
+ break
72
+
73
+ self.features = nn.Sequential(*vgg_avg_pooling)
74
+
75
+ self.max_iters = max_iters
76
+ self.temperature = temperature
77
+ self.do_erode = erode_mask_size > 0
78
+ if self.do_erode:
79
+ self.erode_mask = nn.Conv2d(1, 1, erode_mask_size, padding=erode_mask_size // 2, bias=False)
80
+ self.erode_mask.weight.data.fill_(1)
81
+
82
+ def forward(self, real_img, pred_img, mask):
83
+ with torch.no_grad():
84
+ real_img = (real_img - IMAGENET_MEAN.to(real_img)) / IMAGENET_STD.to(real_img)
85
+ real_feats = self.features(real_img)
86
+
87
+ vertical_sim = torch.exp(-(real_feats[:, :, 1:] - real_feats[:, :, :-1]).pow(2).sum(1, keepdim=True)
88
+ / self.temperature)
89
+ horizontal_sim = torch.exp(-(real_feats[:, :, :, 1:] - real_feats[:, :, :, :-1]).pow(2).sum(1, keepdim=True)
90
+ / self.temperature)
91
+
92
+ mask_scaled = F.interpolate(mask, size=real_feats.shape[-2:], mode='bilinear', align_corners=False)
93
+ if self.do_erode:
94
+ mask_scaled = (self.erode_mask(mask_scaled) > 1).float()
95
+
96
+ cur_knowness = 1 - mask_scaled
97
+
98
+ for iter_i in range(self.max_iters):
99
+ new_top_knowness = F.pad(cur_knowness[:, :, :-1] * vertical_sim, (0, 0, 1, 0), mode='replicate')
100
+ new_bottom_knowness = F.pad(cur_knowness[:, :, 1:] * vertical_sim, (0, 0, 0, 1), mode='replicate')
101
+
102
+ new_left_knowness = F.pad(cur_knowness[:, :, :, :-1] * horizontal_sim, (1, 0, 0, 0), mode='replicate')
103
+ new_right_knowness = F.pad(cur_knowness[:, :, :, 1:] * horizontal_sim, (0, 1, 0, 0), mode='replicate')
104
+
105
+ new_knowness = torch.stack([new_top_knowness, new_bottom_knowness,
106
+ new_left_knowness, new_right_knowness],
107
+ dim=0).max(0).values
108
+
109
+ cur_knowness = torch.max(cur_knowness, new_knowness)
110
+
111
+ cur_knowness = F.interpolate(cur_knowness, size=mask.shape[-2:], mode='bilinear')
112
+ result = torch.min(mask, 1 - cur_knowness)
113
+
114
+ return result
115
+
116
+
117
+ def make_mask_distance_weighter(kind='none', **kwargs):
118
+ if kind == 'none':
119
+ return dummy_distance_weighter
120
+ if kind == 'blur':
121
+ return BlurMask(**kwargs)
122
+ if kind == 'edt':
123
+ return EmulatedEDTMask(**kwargs)
124
+ if kind == 'pps':
125
+ return PropagatePerceptualSim(**kwargs)
126
+ raise ValueError(f'Unknown mask distance weighter kind {kind}')
annotator/lama/saicinpainting/training/losses/feature_matching.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def masked_l2_loss(pred, target, mask, weight_known, weight_missing):
8
+ per_pixel_l2 = F.mse_loss(pred, target, reduction='none')
9
+ pixel_weights = mask * weight_missing + (1 - mask) * weight_known
10
+ return (pixel_weights * per_pixel_l2).mean()
11
+
12
+
13
+ def masked_l1_loss(pred, target, mask, weight_known, weight_missing):
14
+ per_pixel_l1 = F.l1_loss(pred, target, reduction='none')
15
+ pixel_weights = mask * weight_missing + (1 - mask) * weight_known
16
+ return (pixel_weights * per_pixel_l1).mean()
17
+
18
+
19
+ def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None):
20
+ if mask is None:
21
+ res = torch.stack([F.mse_loss(fake_feat, target_feat)
22
+ for fake_feat, target_feat in zip(fake_features, target_features)]).mean()
23
+ else:
24
+ res = 0
25
+ norm = 0
26
+ for fake_feat, target_feat in zip(fake_features, target_features):
27
+ cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False)
28
+ error_weights = 1 - cur_mask
29
+ cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean()
30
+ res = res + cur_val
31
+ norm += 1
32
+ res = res / norm
33
+ return res
annotator/lama/saicinpainting/training/losses/perceptual.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ # from models.ade20k import ModelBuilder
7
+ from annotator.lama.saicinpainting.utils import check_and_warn_input_range
8
+
9
+
10
+ IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]
11
+ IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]
12
+
13
+
14
+ class PerceptualLoss(nn.Module):
15
+ def __init__(self, normalize_inputs=True):
16
+ super(PerceptualLoss, self).__init__()
17
+
18
+ self.normalize_inputs = normalize_inputs
19
+ self.mean_ = IMAGENET_MEAN
20
+ self.std_ = IMAGENET_STD
21
+
22
+ vgg = torchvision.models.vgg19(pretrained=True).features
23
+ vgg_avg_pooling = []
24
+
25
+ for weights in vgg.parameters():
26
+ weights.requires_grad = False
27
+
28
+ for module in vgg.modules():
29
+ if module.__class__.__name__ == 'Sequential':
30
+ continue
31
+ elif module.__class__.__name__ == 'MaxPool2d':
32
+ vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
33
+ else:
34
+ vgg_avg_pooling.append(module)
35
+
36
+ self.vgg = nn.Sequential(*vgg_avg_pooling)
37
+
38
+ def do_normalize_inputs(self, x):
39
+ return (x - self.mean_.to(x.device)) / self.std_.to(x.device)
40
+
41
+ def partial_losses(self, input, target, mask=None):
42
+ check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses')
43
+
44
+ # we expect input and target to be in [0, 1] range
45
+ losses = []
46
+
47
+ if self.normalize_inputs:
48
+ features_input = self.do_normalize_inputs(input)
49
+ features_target = self.do_normalize_inputs(target)
50
+ else:
51
+ features_input = input
52
+ features_target = target
53
+
54
+ for layer in self.vgg[:30]:
55
+
56
+ features_input = layer(features_input)
57
+ features_target = layer(features_target)
58
+
59
+ if layer.__class__.__name__ == 'ReLU':
60
+ loss = F.mse_loss(features_input, features_target, reduction='none')
61
+
62
+ if mask is not None:
63
+ cur_mask = F.interpolate(mask, size=features_input.shape[-2:],
64
+ mode='bilinear', align_corners=False)
65
+ loss = loss * (1 - cur_mask)
66
+
67
+ loss = loss.mean(dim=tuple(range(1, len(loss.shape))))
68
+ losses.append(loss)
69
+
70
+ return losses
71
+
72
+ def forward(self, input, target, mask=None):
73
+ losses = self.partial_losses(input, target, mask=mask)
74
+ return torch.stack(losses).sum(dim=0)
75
+
76
+ def get_global_features(self, input):
77
+ check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features')
78
+
79
+ if self.normalize_inputs:
80
+ features_input = self.do_normalize_inputs(input)
81
+ else:
82
+ features_input = input
83
+
84
+ features_input = self.vgg(features_input)
85
+ return features_input
86
+
87
+
88
+ class ResNetPL(nn.Module):
89
+ def __init__(self, weight=1,
90
+ weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
91
+ super().__init__()
92
+ self.impl = ModelBuilder.get_encoder(weights_path=weights_path,
93
+ arch_encoder=arch_encoder,
94
+ arch_decoder='ppm_deepsup',
95
+ fc_dim=2048,
96
+ segmentation=segmentation)
97
+ self.impl.eval()
98
+ for w in self.impl.parameters():
99
+ w.requires_grad_(False)
100
+
101
+ self.weight = weight
102
+
103
+ def forward(self, pred, target):
104
+ pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)
105
+ target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)
106
+
107
+ pred_feats = self.impl(pred, return_feature_maps=True)
108
+ target_feats = self.impl(target, return_feature_maps=True)
109
+
110
+ result = torch.stack([F.mse_loss(cur_pred, cur_target)
111
+ for cur_pred, cur_target
112
+ in zip(pred_feats, target_feats)]).sum() * self.weight
113
+ return result
annotator/lama/saicinpainting/training/losses/segmentation.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .constants import weights as constant_weights
6
+
7
+
8
+ class CrossEntropy2d(nn.Module):
9
+ def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs):
10
+ """
11
+ weight (Tensor, optional): a manual rescaling weight given to each class.
12
+ If given, has to be a Tensor of size "nclasses"
13
+ """
14
+ super(CrossEntropy2d, self).__init__()
15
+ self.reduction = reduction
16
+ self.ignore_label = ignore_label
17
+ self.weights = weights
18
+ if self.weights is not None:
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+ self.weights = torch.FloatTensor(constant_weights[weights]).to(device)
21
+
22
+ def forward(self, predict, target):
23
+ """
24
+ Args:
25
+ predict:(n, c, h, w)
26
+ target:(n, 1, h, w)
27
+ """
28
+ target = target.long()
29
+ assert not target.requires_grad
30
+ assert predict.dim() == 4, "{0}".format(predict.size())
31
+ assert target.dim() == 4, "{0}".format(target.size())
32
+ assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
33
+ assert target.size(1) == 1, "{0}".format(target.size(1))
34
+ assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2))
35
+ assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3))
36
+ target = target.squeeze(1)
37
+ n, c, h, w = predict.size()
38
+ target_mask = (target >= 0) * (target != self.ignore_label)
39
+ target = target[target_mask]
40
+ predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
41
+ predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
42
+ loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction)
43
+ return loss
annotator/lama/saicinpainting/training/losses/style_loss.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+
5
+
6
+ class PerceptualLoss(nn.Module):
7
+ r"""
8
+ Perceptual loss, VGG-based
9
+ https://arxiv.org/abs/1603.08155
10
+ https://github.com/dxyang/StyleTransfer/blob/master/utils.py
11
+ """
12
+
13
+ def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
14
+ super(PerceptualLoss, self).__init__()
15
+ self.add_module('vgg', VGG19())
16
+ self.criterion = torch.nn.L1Loss()
17
+ self.weights = weights
18
+
19
+ def __call__(self, x, y):
20
+ # Compute features
21
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
22
+
23
+ content_loss = 0.0
24
+ content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
25
+ content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
26
+ content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
27
+ content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
28
+ content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
29
+
30
+
31
+ return content_loss
32
+
33
+
34
+ class VGG19(torch.nn.Module):
35
+ def __init__(self):
36
+ super(VGG19, self).__init__()
37
+ features = models.vgg19(pretrained=True).features
38
+ self.relu1_1 = torch.nn.Sequential()
39
+ self.relu1_2 = torch.nn.Sequential()
40
+
41
+ self.relu2_1 = torch.nn.Sequential()
42
+ self.relu2_2 = torch.nn.Sequential()
43
+
44
+ self.relu3_1 = torch.nn.Sequential()
45
+ self.relu3_2 = torch.nn.Sequential()
46
+ self.relu3_3 = torch.nn.Sequential()
47
+ self.relu3_4 = torch.nn.Sequential()
48
+
49
+ self.relu4_1 = torch.nn.Sequential()
50
+ self.relu4_2 = torch.nn.Sequential()
51
+ self.relu4_3 = torch.nn.Sequential()
52
+ self.relu4_4 = torch.nn.Sequential()
53
+
54
+ self.relu5_1 = torch.nn.Sequential()
55
+ self.relu5_2 = torch.nn.Sequential()
56
+ self.relu5_3 = torch.nn.Sequential()
57
+ self.relu5_4 = torch.nn.Sequential()
58
+
59
+ for x in range(2):
60
+ self.relu1_1.add_module(str(x), features[x])
61
+
62
+ for x in range(2, 4):
63
+ self.relu1_2.add_module(str(x), features[x])
64
+
65
+ for x in range(4, 7):
66
+ self.relu2_1.add_module(str(x), features[x])
67
+
68
+ for x in range(7, 9):
69
+ self.relu2_2.add_module(str(x), features[x])
70
+
71
+ for x in range(9, 12):
72
+ self.relu3_1.add_module(str(x), features[x])
73
+
74
+ for x in range(12, 14):
75
+ self.relu3_2.add_module(str(x), features[x])
76
+
77
+ for x in range(14, 16):
78
+ self.relu3_2.add_module(str(x), features[x])
79
+
80
+ for x in range(16, 18):
81
+ self.relu3_4.add_module(str(x), features[x])
82
+
83
+ for x in range(18, 21):
84
+ self.relu4_1.add_module(str(x), features[x])
85
+
86
+ for x in range(21, 23):
87
+ self.relu4_2.add_module(str(x), features[x])
88
+
89
+ for x in range(23, 25):
90
+ self.relu4_3.add_module(str(x), features[x])
91
+
92
+ for x in range(25, 27):
93
+ self.relu4_4.add_module(str(x), features[x])
94
+
95
+ for x in range(27, 30):
96
+ self.relu5_1.add_module(str(x), features[x])
97
+
98
+ for x in range(30, 32):
99
+ self.relu5_2.add_module(str(x), features[x])
100
+
101
+ for x in range(32, 34):
102
+ self.relu5_3.add_module(str(x), features[x])
103
+
104
+ for x in range(34, 36):
105
+ self.relu5_4.add_module(str(x), features[x])
106
+
107
+ # don't need the gradients, just want the features
108
+ for param in self.parameters():
109
+ param.requires_grad = False
110
+
111
+ def forward(self, x):
112
+ relu1_1 = self.relu1_1(x)
113
+ relu1_2 = self.relu1_2(relu1_1)
114
+
115
+ relu2_1 = self.relu2_1(relu1_2)
116
+ relu2_2 = self.relu2_2(relu2_1)
117
+
118
+ relu3_1 = self.relu3_1(relu2_2)
119
+ relu3_2 = self.relu3_2(relu3_1)
120
+ relu3_3 = self.relu3_3(relu3_2)
121
+ relu3_4 = self.relu3_4(relu3_3)
122
+
123
+ relu4_1 = self.relu4_1(relu3_4)
124
+ relu4_2 = self.relu4_2(relu4_1)
125
+ relu4_3 = self.relu4_3(relu4_2)
126
+ relu4_4 = self.relu4_4(relu4_3)
127
+
128
+ relu5_1 = self.relu5_1(relu4_4)
129
+ relu5_2 = self.relu5_2(relu5_1)
130
+ relu5_3 = self.relu5_3(relu5_2)
131
+ relu5_4 = self.relu5_4(relu5_3)
132
+
133
+ out = {
134
+ 'relu1_1': relu1_1,
135
+ 'relu1_2': relu1_2,
136
+
137
+ 'relu2_1': relu2_1,
138
+ 'relu2_2': relu2_2,
139
+
140
+ 'relu3_1': relu3_1,
141
+ 'relu3_2': relu3_2,
142
+ 'relu3_3': relu3_3,
143
+ 'relu3_4': relu3_4,
144
+
145
+ 'relu4_1': relu4_1,
146
+ 'relu4_2': relu4_2,
147
+ 'relu4_3': relu4_3,
148
+ 'relu4_4': relu4_4,
149
+
150
+ 'relu5_1': relu5_1,
151
+ 'relu5_2': relu5_2,
152
+ 'relu5_3': relu5_3,
153
+ 'relu5_4': relu5_4,
154
+ }
155
+ return out
annotator/lama/saicinpainting/training/modules/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from annotator.lama.saicinpainting.training.modules.ffc import FFCResNetGenerator
4
+ from annotator.lama.saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \
5
+ NLayerDiscriminator, MultidilatedNLayerDiscriminator
6
+
7
+ def make_generator(config, kind, **kwargs):
8
+ logging.info(f'Make generator {kind}')
9
+
10
+ if kind == 'pix2pixhd_multidilated':
11
+ return MultiDilatedGlobalGenerator(**kwargs)
12
+
13
+ if kind == 'pix2pixhd_global':
14
+ return GlobalGenerator(**kwargs)
15
+
16
+ if kind == 'ffc_resnet':
17
+ return FFCResNetGenerator(**kwargs)
18
+
19
+ raise ValueError(f'Unknown generator kind {kind}')
20
+
21
+
22
+ def make_discriminator(kind, **kwargs):
23
+ logging.info(f'Make discriminator {kind}')
24
+
25
+ if kind == 'pix2pixhd_nlayer_multidilated':
26
+ return MultidilatedNLayerDiscriminator(**kwargs)
27
+
28
+ if kind == 'pix2pixhd_nlayer':
29
+ return NLayerDiscriminator(**kwargs)
30
+
31
+ raise ValueError(f'Unknown discriminator kind {kind}')
annotator/lama/saicinpainting/training/modules/base.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Tuple, List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from annotator.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
8
+ from annotator.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv
9
+
10
+
11
+ class BaseDiscriminator(nn.Module):
12
+ @abc.abstractmethod
13
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
14
+ """
15
+ Predict scores and get intermediate activations. Useful for feature matching loss
16
+ :return tuple (scores, list of intermediate activations)
17
+ """
18
+ raise NotImplemented()
19
+
20
+
21
+ def get_conv_block_ctor(kind='default'):
22
+ if not isinstance(kind, str):
23
+ return kind
24
+ if kind == 'default':
25
+ return nn.Conv2d
26
+ if kind == 'depthwise':
27
+ return DepthWiseSeperableConv
28
+ if kind == 'multidilated':
29
+ return MultidilatedConv
30
+ raise ValueError(f'Unknown convolutional block kind {kind}')
31
+
32
+
33
+ def get_norm_layer(kind='bn'):
34
+ if not isinstance(kind, str):
35
+ return kind
36
+ if kind == 'bn':
37
+ return nn.BatchNorm2d
38
+ if kind == 'in':
39
+ return nn.InstanceNorm2d
40
+ raise ValueError(f'Unknown norm block kind {kind}')
41
+
42
+
43
+ def get_activation(kind='tanh'):
44
+ if kind == 'tanh':
45
+ return nn.Tanh()
46
+ if kind == 'sigmoid':
47
+ return nn.Sigmoid()
48
+ if kind is False:
49
+ return nn.Identity()
50
+ raise ValueError(f'Unknown activation kind {kind}')
51
+
52
+
53
+ class SimpleMultiStepGenerator(nn.Module):
54
+ def __init__(self, steps: List[nn.Module]):
55
+ super().__init__()
56
+ self.steps = nn.ModuleList(steps)
57
+
58
+ def forward(self, x):
59
+ cur_in = x
60
+ outs = []
61
+ for step in self.steps:
62
+ cur_out = step(cur_in)
63
+ outs.append(cur_out)
64
+ cur_in = torch.cat((cur_in, cur_out), dim=1)
65
+ return torch.cat(outs[::-1], dim=1)
66
+
67
+ def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
68
+ if kind == 'convtranspose':
69
+ return [nn.ConvTranspose2d(min(max_features, ngf * mult),
70
+ min(max_features, int(ngf * mult / 2)),
71
+ kernel_size=3, stride=2, padding=1, output_padding=1),
72
+ norm_layer(min(max_features, int(ngf * mult / 2))), activation]
73
+ elif kind == 'bilinear':
74
+ return [nn.Upsample(scale_factor=2, mode='bilinear'),
75
+ DepthWiseSeperableConv(min(max_features, ngf * mult),
76
+ min(max_features, int(ngf * mult / 2)),
77
+ kernel_size=3, stride=1, padding=1),
78
+ norm_layer(min(max_features, int(ngf * mult / 2))), activation]
79
+ else:
80
+ raise Exception(f"Invalid deconv kind: {kind}")
annotator/lama/saicinpainting/training/modules/depthwise_sep_conv.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class DepthWiseSeperableConv(nn.Module):
5
+ def __init__(self, in_dim, out_dim, *args, **kwargs):
6
+ super().__init__()
7
+ if 'groups' in kwargs:
8
+ # ignoring groups for Depthwise Sep Conv
9
+ del kwargs['groups']
10
+
11
+ self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs)
12
+ self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1)
13
+
14
+ def forward(self, x):
15
+ out = self.depthwise(x)
16
+ out = self.pointwise(out)
17
+ return out
annotator/lama/saicinpainting/training/modules/fake_fakes.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from kornia import SamplePadding
3
+ from kornia.augmentation import RandomAffine, CenterCrop
4
+
5
+
6
+ class FakeFakesGenerator:
7
+ def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2):
8
+ self.grad_aug = RandomAffine(degrees=360,
9
+ translate=0.2,
10
+ padding_mode=SamplePadding.REFLECTION,
11
+ keepdim=False,
12
+ p=1)
13
+ self.img_aug = RandomAffine(degrees=img_aug_degree,
14
+ translate=img_aug_translate,
15
+ padding_mode=SamplePadding.REFLECTION,
16
+ keepdim=True,
17
+ p=1)
18
+ self.aug_proba = aug_proba
19
+
20
+ def __call__(self, input_images, masks):
21
+ blend_masks = self._fill_masks_with_gradient(masks)
22
+ blend_target = self._make_blend_target(input_images)
23
+ result = input_images * (1 - blend_masks) + blend_target * blend_masks
24
+ return result, blend_masks
25
+
26
+ def _make_blend_target(self, input_images):
27
+ batch_size = input_images.shape[0]
28
+ permuted = input_images[torch.randperm(batch_size)]
29
+ augmented = self.img_aug(input_images)
30
+ is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float()
31
+ result = augmented * is_aug + permuted * (1 - is_aug)
32
+ return result
33
+
34
+ def _fill_masks_with_gradient(self, masks):
35
+ batch_size, _, height, width = masks.shape
36
+ grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \
37
+ .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2)
38
+ grad = self.grad_aug(grad)
39
+ grad = CenterCrop((height, width))(grad)
40
+ grad *= masks
41
+
42
+ grad_for_min = grad + (1 - masks) * 10
43
+ grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None]
44
+ grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6
45
+ grad.clamp_(min=0, max=1)
46
+
47
+ return grad
annotator/lama/saicinpainting/training/modules/ffc.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fast Fourier Convolution NeurIPS 2020
2
+ # original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
3
+ # paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from annotator.lama.saicinpainting.training.modules.base import get_activation, BaseDiscriminator
11
+ from annotator.lama.saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper
12
+ from annotator.lama.saicinpainting.training.modules.squeeze_excitation import SELayer
13
+ from annotator.lama.saicinpainting.utils import get_shape
14
+
15
+
16
+ class FFCSE_block(nn.Module):
17
+
18
+ def __init__(self, channels, ratio_g):
19
+ super(FFCSE_block, self).__init__()
20
+ in_cg = int(channels * ratio_g)
21
+ in_cl = channels - in_cg
22
+ r = 16
23
+
24
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
25
+ self.conv1 = nn.Conv2d(channels, channels // r,
26
+ kernel_size=1, bias=True)
27
+ self.relu1 = nn.ReLU(inplace=True)
28
+ self.conv_a2l = None if in_cl == 0 else nn.Conv2d(
29
+ channels // r, in_cl, kernel_size=1, bias=True)
30
+ self.conv_a2g = None if in_cg == 0 else nn.Conv2d(
31
+ channels // r, in_cg, kernel_size=1, bias=True)
32
+ self.sigmoid = nn.Sigmoid()
33
+
34
+ def forward(self, x):
35
+ x = x if type(x) is tuple else (x, 0)
36
+ id_l, id_g = x
37
+
38
+ x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1)
39
+ x = self.avgpool(x)
40
+ x = self.relu1(self.conv1(x))
41
+
42
+ x_l = 0 if self.conv_a2l is None else id_l * \
43
+ self.sigmoid(self.conv_a2l(x))
44
+ x_g = 0 if self.conv_a2g is None else id_g * \
45
+ self.sigmoid(self.conv_a2g(x))
46
+ return x_l, x_g
47
+
48
+
49
+ class FourierUnit(nn.Module):
50
+
51
+ def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
52
+ spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
53
+ # bn_layer not used
54
+ super(FourierUnit, self).__init__()
55
+ self.groups = groups
56
+
57
+ self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
58
+ out_channels=out_channels * 2,
59
+ kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
60
+ self.bn = torch.nn.BatchNorm2d(out_channels * 2)
61
+ self.relu = torch.nn.ReLU(inplace=True)
62
+
63
+ # squeeze and excitation block
64
+ self.use_se = use_se
65
+ if use_se:
66
+ if se_kwargs is None:
67
+ se_kwargs = {}
68
+ self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
69
+
70
+ self.spatial_scale_factor = spatial_scale_factor
71
+ self.spatial_scale_mode = spatial_scale_mode
72
+ self.spectral_pos_encoding = spectral_pos_encoding
73
+ self.ffc3d = ffc3d
74
+ self.fft_norm = fft_norm
75
+
76
+ def forward(self, x):
77
+ batch = x.shape[0]
78
+
79
+ if self.spatial_scale_factor is not None:
80
+ orig_size = x.shape[-2:]
81
+ x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)
82
+
83
+ r_size = x.size()
84
+ # (batch, c, h, w/2+1, 2)
85
+ fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
86
+ ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
87
+ ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
88
+ ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
89
+ ffted = ffted.view((batch, -1,) + ffted.size()[3:])
90
+
91
+ if self.spectral_pos_encoding:
92
+ height, width = ffted.shape[-2:]
93
+ coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
94
+ coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
95
+ ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
96
+
97
+ if self.use_se:
98
+ ffted = self.se(ffted)
99
+
100
+ ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
101
+ ffted = self.relu(self.bn(ffted))
102
+
103
+ ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
104
+ 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
105
+ ffted = torch.complex(ffted[..., 0], ffted[..., 1])
106
+
107
+ ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
108
+ output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
109
+
110
+ if self.spatial_scale_factor is not None:
111
+ output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
112
+
113
+ return output
114
+
115
+
116
+ class SeparableFourierUnit(nn.Module):
117
+
118
+ def __init__(self, in_channels, out_channels, groups=1, kernel_size=3):
119
+ # bn_layer not used
120
+ super(SeparableFourierUnit, self).__init__()
121
+ self.groups = groups
122
+ row_out_channels = out_channels // 2
123
+ col_out_channels = out_channels - row_out_channels
124
+ self.row_conv = torch.nn.Conv2d(in_channels=in_channels * 2,
125
+ out_channels=row_out_channels * 2,
126
+ kernel_size=(kernel_size, 1), # kernel size is always like this, but the data will be transposed
127
+ stride=1, padding=(kernel_size // 2, 0),
128
+ padding_mode='reflect',
129
+ groups=self.groups, bias=False)
130
+ self.col_conv = torch.nn.Conv2d(in_channels=in_channels * 2,
131
+ out_channels=col_out_channels * 2,
132
+ kernel_size=(kernel_size, 1), # kernel size is always like this, but the data will be transposed
133
+ stride=1, padding=(kernel_size // 2, 0),
134
+ padding_mode='reflect',
135
+ groups=self.groups, bias=False)
136
+ self.row_bn = torch.nn.BatchNorm2d(row_out_channels * 2)
137
+ self.col_bn = torch.nn.BatchNorm2d(col_out_channels * 2)
138
+ self.relu = torch.nn.ReLU(inplace=True)
139
+
140
+ def process_branch(self, x, conv, bn):
141
+ batch = x.shape[0]
142
+
143
+ r_size = x.size()
144
+ # (batch, c, h, w/2+1, 2)
145
+ ffted = torch.fft.rfft(x, norm="ortho")
146
+ ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
147
+ ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
148
+ ffted = ffted.view((batch, -1,) + ffted.size()[3:])
149
+
150
+ ffted = self.relu(bn(conv(ffted)))
151
+
152
+ ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
153
+ 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
154
+ ffted = torch.complex(ffted[..., 0], ffted[..., 1])
155
+
156
+ output = torch.fft.irfft(ffted, s=x.shape[-1:], norm="ortho")
157
+ return output
158
+
159
+
160
+ def forward(self, x):
161
+ rowwise = self.process_branch(x, self.row_conv, self.row_bn)
162
+ colwise = self.process_branch(x.permute(0, 1, 3, 2), self.col_conv, self.col_bn).permute(0, 1, 3, 2)
163
+ out = torch.cat((rowwise, colwise), dim=1)
164
+ return out
165
+
166
+
167
+ class SpectralTransform(nn.Module):
168
+
169
+ def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, separable_fu=False, **fu_kwargs):
170
+ # bn_layer not used
171
+ super(SpectralTransform, self).__init__()
172
+ self.enable_lfu = enable_lfu
173
+ if stride == 2:
174
+ self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
175
+ else:
176
+ self.downsample = nn.Identity()
177
+
178
+ self.stride = stride
179
+ self.conv1 = nn.Sequential(
180
+ nn.Conv2d(in_channels, out_channels //
181
+ 2, kernel_size=1, groups=groups, bias=False),
182
+ nn.BatchNorm2d(out_channels // 2),
183
+ nn.ReLU(inplace=True)
184
+ )
185
+ fu_class = SeparableFourierUnit if separable_fu else FourierUnit
186
+ self.fu = fu_class(
187
+ out_channels // 2, out_channels // 2, groups, **fu_kwargs)
188
+ if self.enable_lfu:
189
+ self.lfu = fu_class(
190
+ out_channels // 2, out_channels // 2, groups)
191
+ self.conv2 = torch.nn.Conv2d(
192
+ out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
193
+
194
+ def forward(self, x):
195
+
196
+ x = self.downsample(x)
197
+ x = self.conv1(x)
198
+ output = self.fu(x)
199
+
200
+ if self.enable_lfu:
201
+ n, c, h, w = x.shape
202
+ split_no = 2
203
+ split_s = h // split_no
204
+ xs = torch.cat(torch.split(
205
+ x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
206
+ xs = torch.cat(torch.split(xs, split_s, dim=-1),
207
+ dim=1).contiguous()
208
+ xs = self.lfu(xs)
209
+ xs = xs.repeat(1, 1, split_no, split_no).contiguous()
210
+ else:
211
+ xs = 0
212
+
213
+ output = self.conv2(x + output + xs)
214
+
215
+ return output
216
+
217
+
218
+ class FFC(nn.Module):
219
+
220
+ def __init__(self, in_channels, out_channels, kernel_size,
221
+ ratio_gin, ratio_gout, stride=1, padding=0,
222
+ dilation=1, groups=1, bias=False, enable_lfu=True,
223
+ padding_type='reflect', gated=False, **spectral_kwargs):
224
+ super(FFC, self).__init__()
225
+
226
+ assert stride == 1 or stride == 2, "Stride should be 1 or 2."
227
+ self.stride = stride
228
+
229
+ in_cg = int(in_channels * ratio_gin)
230
+ in_cl = in_channels - in_cg
231
+ out_cg = int(out_channels * ratio_gout)
232
+ out_cl = out_channels - out_cg
233
+ #groups_g = 1 if groups == 1 else int(groups * ratio_gout)
234
+ #groups_l = 1 if groups == 1 else groups - groups_g
235
+
236
+ self.ratio_gin = ratio_gin
237
+ self.ratio_gout = ratio_gout
238
+ self.global_in_num = in_cg
239
+
240
+ module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
241
+ self.convl2l = module(in_cl, out_cl, kernel_size,
242
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
243
+ module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
244
+ self.convl2g = module(in_cl, out_cg, kernel_size,
245
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
246
+ module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
247
+ self.convg2l = module(in_cg, out_cl, kernel_size,
248
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
249
+ module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
250
+ self.convg2g = module(
251
+ in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
252
+
253
+ self.gated = gated
254
+ module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
255
+ self.gate = module(in_channels, 2, 1)
256
+
257
+ def forward(self, x):
258
+ x_l, x_g = x if type(x) is tuple else (x, 0)
259
+ out_xl, out_xg = 0, 0
260
+
261
+ if self.gated:
262
+ total_input_parts = [x_l]
263
+ if torch.is_tensor(x_g):
264
+ total_input_parts.append(x_g)
265
+ total_input = torch.cat(total_input_parts, dim=1)
266
+
267
+ gates = torch.sigmoid(self.gate(total_input))
268
+ g2l_gate, l2g_gate = gates.chunk(2, dim=1)
269
+ else:
270
+ g2l_gate, l2g_gate = 1, 1
271
+
272
+ if self.ratio_gout != 1:
273
+ out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
274
+ if self.ratio_gout != 0:
275
+ out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
276
+
277
+ return out_xl, out_xg
278
+
279
+
280
+ class FFC_BN_ACT(nn.Module):
281
+
282
+ def __init__(self, in_channels, out_channels,
283
+ kernel_size, ratio_gin, ratio_gout,
284
+ stride=1, padding=0, dilation=1, groups=1, bias=False,
285
+ norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
286
+ padding_type='reflect',
287
+ enable_lfu=True, **kwargs):
288
+ super(FFC_BN_ACT, self).__init__()
289
+ self.ffc = FFC(in_channels, out_channels, kernel_size,
290
+ ratio_gin, ratio_gout, stride, padding, dilation,
291
+ groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
292
+ lnorm = nn.Identity if ratio_gout == 1 else norm_layer
293
+ gnorm = nn.Identity if ratio_gout == 0 else norm_layer
294
+ global_channels = int(out_channels * ratio_gout)
295
+ self.bn_l = lnorm(out_channels - global_channels)
296
+ self.bn_g = gnorm(global_channels)
297
+
298
+ lact = nn.Identity if ratio_gout == 1 else activation_layer
299
+ gact = nn.Identity if ratio_gout == 0 else activation_layer
300
+ self.act_l = lact(inplace=True)
301
+ self.act_g = gact(inplace=True)
302
+
303
+ def forward(self, x):
304
+ x_l, x_g = self.ffc(x)
305
+ x_l = self.act_l(self.bn_l(x_l))
306
+ x_g = self.act_g(self.bn_g(x_g))
307
+ return x_l, x_g
308
+
309
+
310
+ class FFCResnetBlock(nn.Module):
311
+ def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
312
+ spatial_transform_kwargs=None, inline=False, **conv_kwargs):
313
+ super().__init__()
314
+ self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
315
+ norm_layer=norm_layer,
316
+ activation_layer=activation_layer,
317
+ padding_type=padding_type,
318
+ **conv_kwargs)
319
+ self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
320
+ norm_layer=norm_layer,
321
+ activation_layer=activation_layer,
322
+ padding_type=padding_type,
323
+ **conv_kwargs)
324
+ if spatial_transform_kwargs is not None:
325
+ self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs)
326
+ self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs)
327
+ self.inline = inline
328
+
329
+ def forward(self, x):
330
+ if self.inline:
331
+ x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
332
+ else:
333
+ x_l, x_g = x if type(x) is tuple else (x, 0)
334
+
335
+ id_l, id_g = x_l, x_g
336
+
337
+ x_l, x_g = self.conv1((x_l, x_g))
338
+ x_l, x_g = self.conv2((x_l, x_g))
339
+
340
+ x_l, x_g = id_l + x_l, id_g + x_g
341
+ out = x_l, x_g
342
+ if self.inline:
343
+ out = torch.cat(out, dim=1)
344
+ return out
345
+
346
+
347
+ class ConcatTupleLayer(nn.Module):
348
+ def forward(self, x):
349
+ assert isinstance(x, tuple)
350
+ x_l, x_g = x
351
+ assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
352
+ if not torch.is_tensor(x_g):
353
+ return x_l
354
+ return torch.cat(x, dim=1)
355
+
356
+
357
+ class FFCResNetGenerator(nn.Module):
358
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
359
+ padding_type='reflect', activation_layer=nn.ReLU,
360
+ up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True),
361
+ init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={},
362
+ spatial_transform_layers=None, spatial_transform_kwargs={},
363
+ add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
364
+ assert (n_blocks >= 0)
365
+ super().__init__()
366
+
367
+ model = [nn.ReflectionPad2d(3),
368
+ FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer,
369
+ activation_layer=activation_layer, **init_conv_kwargs)]
370
+
371
+ ### downsample
372
+ for i in range(n_downsampling):
373
+ mult = 2 ** i
374
+ if i == n_downsampling - 1:
375
+ cur_conv_kwargs = dict(downsample_conv_kwargs)
376
+ cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0)
377
+ else:
378
+ cur_conv_kwargs = downsample_conv_kwargs
379
+ model += [FFC_BN_ACT(min(max_features, ngf * mult),
380
+ min(max_features, ngf * mult * 2),
381
+ kernel_size=3, stride=2, padding=1,
382
+ norm_layer=norm_layer,
383
+ activation_layer=activation_layer,
384
+ **cur_conv_kwargs)]
385
+
386
+ mult = 2 ** n_downsampling
387
+ feats_num_bottleneck = min(max_features, ngf * mult)
388
+
389
+ ### resnet blocks
390
+ for i in range(n_blocks):
391
+ cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer,
392
+ norm_layer=norm_layer, **resnet_conv_kwargs)
393
+ if spatial_transform_layers is not None and i in spatial_transform_layers:
394
+ cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs)
395
+ model += [cur_resblock]
396
+
397
+ model += [ConcatTupleLayer()]
398
+
399
+ ### upsample
400
+ for i in range(n_downsampling):
401
+ mult = 2 ** (n_downsampling - i)
402
+ model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
403
+ min(max_features, int(ngf * mult / 2)),
404
+ kernel_size=3, stride=2, padding=1, output_padding=1),
405
+ up_norm_layer(min(max_features, int(ngf * mult / 2))),
406
+ up_activation]
407
+
408
+ if out_ffc:
409
+ model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer,
410
+ norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]
411
+
412
+ model += [nn.ReflectionPad2d(3),
413
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
414
+ if add_out_act:
415
+ model.append(get_activation('tanh' if add_out_act is True else add_out_act))
416
+ self.model = nn.Sequential(*model)
417
+
418
+ def forward(self, input):
419
+ return self.model(input)
420
+
421
+
422
+ class FFCNLayerDiscriminator(BaseDiscriminator):
423
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, max_features=512,
424
+ init_conv_kwargs={}, conv_kwargs={}):
425
+ super().__init__()
426
+ self.n_layers = n_layers
427
+
428
+ def _act_ctor(inplace=True):
429
+ return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
430
+
431
+ kw = 3
432
+ padw = int(np.ceil((kw-1.0)/2))
433
+ sequence = [[FFC_BN_ACT(input_nc, ndf, kernel_size=kw, padding=padw, norm_layer=norm_layer,
434
+ activation_layer=_act_ctor, **init_conv_kwargs)]]
435
+
436
+ nf = ndf
437
+ for n in range(1, n_layers):
438
+ nf_prev = nf
439
+ nf = min(nf * 2, max_features)
440
+
441
+ cur_model = [
442
+ FFC_BN_ACT(nf_prev, nf,
443
+ kernel_size=kw, stride=2, padding=padw,
444
+ norm_layer=norm_layer,
445
+ activation_layer=_act_ctor,
446
+ **conv_kwargs)
447
+ ]
448
+ sequence.append(cur_model)
449
+
450
+ nf_prev = nf
451
+ nf = min(nf * 2, 512)
452
+
453
+ cur_model = [
454
+ FFC_BN_ACT(nf_prev, nf,
455
+ kernel_size=kw, stride=1, padding=padw,
456
+ norm_layer=norm_layer,
457
+ activation_layer=lambda *args, **kwargs: nn.LeakyReLU(*args, negative_slope=0.2, **kwargs),
458
+ **conv_kwargs),
459
+ ConcatTupleLayer()
460
+ ]
461
+ sequence.append(cur_model)
462
+
463
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
464
+
465
+ for n in range(len(sequence)):
466
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
467
+
468
+ def get_all_activations(self, x):
469
+ res = [x]
470
+ for n in range(self.n_layers + 2):
471
+ model = getattr(self, 'model' + str(n))
472
+ res.append(model(res[-1]))
473
+ return res[1:]
474
+
475
+ def forward(self, x):
476
+ act = self.get_all_activations(x)
477
+ feats = []
478
+ for out in act[:-1]:
479
+ if isinstance(out, tuple):
480
+ if torch.is_tensor(out[1]):
481
+ out = torch.cat(out, dim=1)
482
+ else:
483
+ out = out[0]
484
+ feats.append(out)
485
+ return act[-1], feats
annotator/lama/saicinpainting/training/modules/multidilated_conv.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import random
4
+ from annotator.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
5
+
6
+ class MultidilatedConv(nn.Module):
7
+ def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
8
+ shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs):
9
+ super().__init__()
10
+ convs = []
11
+ self.equal_dim = equal_dim
12
+ assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
13
+ if comb_mode in ('cat_out', 'cat_both'):
14
+ self.cat_out = True
15
+ if equal_dim:
16
+ assert out_dim % dilation_num == 0
17
+ out_dims = [out_dim // dilation_num] * dilation_num
18
+ self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], [])
19
+ else:
20
+ out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
21
+ out_dims.append(out_dim - sum(out_dims))
22
+ index = []
23
+ starts = [0] + out_dims[:-1]
24
+ lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
25
+ for i in range(out_dims[-1]):
26
+ for j in range(dilation_num):
27
+ index += list(range(starts[j], starts[j] + lengths[j]))
28
+ starts[j] += lengths[j]
29
+ self.index = index
30
+ assert(len(index) == out_dim)
31
+ self.out_dims = out_dims
32
+ else:
33
+ self.cat_out = False
34
+ self.out_dims = [out_dim] * dilation_num
35
+
36
+ if comb_mode in ('cat_in', 'cat_both'):
37
+ if equal_dim:
38
+ assert in_dim % dilation_num == 0
39
+ in_dims = [in_dim // dilation_num] * dilation_num
40
+ else:
41
+ in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
42
+ in_dims.append(in_dim - sum(in_dims))
43
+ self.in_dims = in_dims
44
+ self.cat_in = True
45
+ else:
46
+ self.cat_in = False
47
+ self.in_dims = [in_dim] * dilation_num
48
+
49
+ conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
50
+ dilation = min_dilation
51
+ for i in range(dilation_num):
52
+ if isinstance(padding, int):
53
+ cur_padding = padding * dilation
54
+ else:
55
+ cur_padding = padding[i]
56
+ convs.append(conv_type(
57
+ self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
58
+ ))
59
+ if i > 0 and shared_weights:
60
+ convs[-1].weight = convs[0].weight
61
+ convs[-1].bias = convs[0].bias
62
+ dilation *= 2
63
+ self.convs = nn.ModuleList(convs)
64
+
65
+ self.shuffle_in_channels = shuffle_in_channels
66
+ if self.shuffle_in_channels:
67
+ # shuffle list as shuffling of tensors is nondeterministic
68
+ in_channels_permute = list(range(in_dim))
69
+ random.shuffle(in_channels_permute)
70
+ # save as buffer so it is saved and loaded with checkpoint
71
+ self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))
72
+
73
+ def forward(self, x):
74
+ if self.shuffle_in_channels:
75
+ x = x[:, self.in_channels_permute]
76
+
77
+ outs = []
78
+ if self.cat_in:
79
+ if self.equal_dim:
80
+ x = x.chunk(len(self.convs), dim=1)
81
+ else:
82
+ new_x = []
83
+ start = 0
84
+ for dim in self.in_dims:
85
+ new_x.append(x[:, start:start+dim])
86
+ start += dim
87
+ x = new_x
88
+ for i, conv in enumerate(self.convs):
89
+ if self.cat_in:
90
+ input = x[i]
91
+ else:
92
+ input = x
93
+ outs.append(conv(input))
94
+ if self.cat_out:
95
+ out = torch.cat(outs, dim=1)[:, self.index]
96
+ else:
97
+ out = sum(outs)
98
+ return out
annotator/lama/saicinpainting/training/modules/multiscale.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Union, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from annotator.lama.saicinpainting.training.modules.base import get_conv_block_ctor, get_activation
8
+ from annotator.lama.saicinpainting.training.modules.pix2pixhd import ResnetBlock
9
+
10
+
11
+ class ResNetHead(nn.Module):
12
+ def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
13
+ padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)):
14
+ assert (n_blocks >= 0)
15
+ super(ResNetHead, self).__init__()
16
+
17
+ conv_layer = get_conv_block_ctor(conv_kind)
18
+
19
+ model = [nn.ReflectionPad2d(3),
20
+ conv_layer(input_nc, ngf, kernel_size=7, padding=0),
21
+ norm_layer(ngf),
22
+ activation]
23
+
24
+ ### downsample
25
+ for i in range(n_downsampling):
26
+ mult = 2 ** i
27
+ model += [conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
28
+ norm_layer(ngf * mult * 2),
29
+ activation]
30
+
31
+ mult = 2 ** n_downsampling
32
+
33
+ ### resnet blocks
34
+ for i in range(n_blocks):
35
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
36
+ conv_kind=conv_kind)]
37
+
38
+ self.model = nn.Sequential(*model)
39
+
40
+ def forward(self, input):
41
+ return self.model(input)
42
+
43
+
44
+ class ResNetTail(nn.Module):
45
+ def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
46
+ padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
47
+ up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
48
+ add_in_proj=None):
49
+ assert (n_blocks >= 0)
50
+ super(ResNetTail, self).__init__()
51
+
52
+ mult = 2 ** n_downsampling
53
+
54
+ model = []
55
+
56
+ if add_in_proj is not None:
57
+ model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1))
58
+
59
+ ### resnet blocks
60
+ for i in range(n_blocks):
61
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
62
+ conv_kind=conv_kind)]
63
+
64
+ ### upsample
65
+ for i in range(n_downsampling):
66
+ mult = 2 ** (n_downsampling - i)
67
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
68
+ output_padding=1),
69
+ up_norm_layer(int(ngf * mult / 2)),
70
+ up_activation]
71
+ self.model = nn.Sequential(*model)
72
+
73
+ out_layers = []
74
+ for _ in range(out_extra_layers_n):
75
+ out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0),
76
+ up_norm_layer(ngf),
77
+ up_activation]
78
+ out_layers += [nn.ReflectionPad2d(3),
79
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
80
+
81
+ if add_out_act:
82
+ out_layers.append(get_activation('tanh' if add_out_act is True else add_out_act))
83
+
84
+ self.out_proj = nn.Sequential(*out_layers)
85
+
86
+ def forward(self, input, return_last_act=False):
87
+ features = self.model(input)
88
+ out = self.out_proj(features)
89
+ if return_last_act:
90
+ return out, features
91
+ else:
92
+ return out
93
+
94
+
95
+ class MultiscaleResNet(nn.Module):
96
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3,
97
+ norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
98
+ up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
99
+ out_cumulative=False, return_only_hr=False):
100
+ super().__init__()
101
+
102
+ self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling,
103
+ n_blocks=n_blocks_head, norm_layer=norm_layer, padding_type=padding_type,
104
+ conv_kind=conv_kind, activation=activation)
105
+ for i in range(n_scales)])
106
+ tail_in_feats = ngf * (2 ** n_downsampling) + ngf
107
+ self.tails = nn.ModuleList([ResNetTail(output_nc,
108
+ ngf=ngf, n_downsampling=n_downsampling,
109
+ n_blocks=n_blocks_tail, norm_layer=norm_layer, padding_type=padding_type,
110
+ conv_kind=conv_kind, activation=activation, up_norm_layer=up_norm_layer,
111
+ up_activation=up_activation, add_out_act=add_out_act,
112
+ out_extra_layers_n=out_extra_layers_n,
113
+ add_in_proj=None if (i == n_scales - 1) else tail_in_feats)
114
+ for i in range(n_scales)])
115
+
116
+ self.out_cumulative = out_cumulative
117
+ self.return_only_hr = return_only_hr
118
+
119
+ @property
120
+ def num_scales(self):
121
+ return len(self.heads)
122
+
123
+ def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
124
+ -> Union[torch.Tensor, List[torch.Tensor]]:
125
+ """
126
+ :param ms_inputs: List of inputs of different resolutions from HR to LR
127
+ :param smallest_scales_num: int or None, number of smallest scales to take at input
128
+ :return: Depending on return_only_hr:
129
+ True: Only the most HR output
130
+ False: List of outputs of different resolutions from HR to LR
131
+ """
132
+ if smallest_scales_num is None:
133
+ assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num)
134
+ smallest_scales_num = len(self.heads)
135
+ else:
136
+ assert smallest_scales_num == len(ms_inputs) <= len(self.heads), (len(self.heads), len(ms_inputs), smallest_scales_num)
137
+
138
+ cur_heads = self.heads[-smallest_scales_num:]
139
+ ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)]
140
+
141
+ all_outputs = []
142
+ prev_tail_features = None
143
+ for i in range(len(ms_features)):
144
+ scale_i = -i - 1
145
+
146
+ cur_tail_input = ms_features[-i - 1]
147
+ if prev_tail_features is not None:
148
+ if prev_tail_features.shape != cur_tail_input.shape:
149
+ prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:],
150
+ mode='bilinear', align_corners=False)
151
+ cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1)
152
+
153
+ cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True)
154
+
155
+ prev_tail_features = cur_tail_feats
156
+ all_outputs.append(cur_out)
157
+
158
+ if self.out_cumulative:
159
+ all_outputs_cum = [all_outputs[0]]
160
+ for i in range(1, len(ms_features)):
161
+ cur_out = all_outputs[i]
162
+ cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:],
163
+ mode='bilinear', align_corners=False)
164
+ all_outputs_cum.append(cur_out_cum)
165
+ all_outputs = all_outputs_cum
166
+
167
+ if self.return_only_hr:
168
+ return all_outputs[-1]
169
+ else:
170
+ return all_outputs[::-1]
171
+
172
+
173
+ class MultiscaleDiscriminatorSimple(nn.Module):
174
+ def __init__(self, ms_impl):
175
+ super().__init__()
176
+ self.ms_impl = nn.ModuleList(ms_impl)
177
+
178
+ @property
179
+ def num_scales(self):
180
+ return len(self.ms_impl)
181
+
182
+ def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
183
+ -> List[Tuple[torch.Tensor, List[torch.Tensor]]]:
184
+ """
185
+ :param ms_inputs: List of inputs of different resolutions from HR to LR
186
+ :param smallest_scales_num: int or None, number of smallest scales to take at input
187
+ :return: List of pairs (prediction, features) for different resolutions from HR to LR
188
+ """
189
+ if smallest_scales_num is None:
190
+ assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
191
+ smallest_scales_num = len(self.heads)
192
+ else:
193
+ assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \
194
+ (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
195
+
196
+ return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)]
197
+
198
+
199
+ class SingleToMultiScaleInputMixin:
200
+ def forward(self, x: torch.Tensor) -> List:
201
+ orig_height, orig_width = x.shape[2:]
202
+ factors = [2 ** i for i in range(self.num_scales)]
203
+ ms_inputs = [F.interpolate(x, size=(orig_height // f, orig_width // f), mode='bilinear', align_corners=False)
204
+ for f in factors]
205
+ return super().forward(ms_inputs)
206
+
207
+
208
+ class GeneratorMultiToSingleOutputMixin:
209
+ def forward(self, x):
210
+ return super().forward(x)[0]
211
+
212
+
213
+ class DiscriminatorMultiToSingleOutputMixin:
214
+ def forward(self, x):
215
+ out_feat_tuples = super().forward(x)
216
+ return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist]
217
+
218
+
219
+ class DiscriminatorMultiToSingleOutputStackedMixin:
220
+ def __init__(self, *args, return_feats_only_levels=None, **kwargs):
221
+ super().__init__(*args, **kwargs)
222
+ self.return_feats_only_levels = return_feats_only_levels
223
+
224
+ def forward(self, x):
225
+ out_feat_tuples = super().forward(x)
226
+ outs = [out for out, _ in out_feat_tuples]
227
+ scaled_outs = [outs[0]] + [F.interpolate(cur_out, size=outs[0].shape[-2:],
228
+ mode='bilinear', align_corners=False)
229
+ for cur_out in outs[1:]]
230
+ out = torch.cat(scaled_outs, dim=1)
231
+ if self.return_feats_only_levels is not None:
232
+ feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels]
233
+ else:
234
+ feat_lists = [flist for _, flist in out_feat_tuples]
235
+ feats = [f for flist in feat_lists for f in flist]
236
+ return out, feats
237
+
238
+
239
+ class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, MultiscaleDiscriminatorSimple):
240
+ pass
241
+
242
+
243
+ class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet):
244
+ pass
annotator/lama/saicinpainting/training/modules/pix2pixhd.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py
2
+ import collections
3
+ from functools import partial
4
+ import functools
5
+ import logging
6
+ from collections import defaultdict
7
+
8
+ import numpy as np
9
+ import torch.nn as nn
10
+
11
+ from annotator.lama.saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation
12
+ from annotator.lama.saicinpainting.training.modules.ffc import FFCResnetBlock
13
+ from annotator.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv
14
+
15
+ class DotDict(defaultdict):
16
+ # https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
17
+ """dot.notation access to dictionary attributes"""
18
+ __getattr__ = defaultdict.get
19
+ __setattr__ = defaultdict.__setitem__
20
+ __delattr__ = defaultdict.__delitem__
21
+
22
+ class Identity(nn.Module):
23
+ def __init__(self):
24
+ super().__init__()
25
+
26
+ def forward(self, x):
27
+ return x
28
+
29
+
30
+ class ResnetBlock(nn.Module):
31
+ def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
32
+ dilation=1, in_dim=None, groups=1, second_dilation=None):
33
+ super(ResnetBlock, self).__init__()
34
+ self.in_dim = in_dim
35
+ self.dim = dim
36
+ if second_dilation is None:
37
+ second_dilation = dilation
38
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
39
+ conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
40
+ second_dilation=second_dilation)
41
+
42
+ if self.in_dim is not None:
43
+ self.input_conv = nn.Conv2d(in_dim, dim, 1)
44
+
45
+ self.out_channnels = dim
46
+
47
+ def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
48
+ dilation=1, in_dim=None, groups=1, second_dilation=1):
49
+ conv_layer = get_conv_block_ctor(conv_kind)
50
+
51
+ conv_block = []
52
+ p = 0
53
+ if padding_type == 'reflect':
54
+ conv_block += [nn.ReflectionPad2d(dilation)]
55
+ elif padding_type == 'replicate':
56
+ conv_block += [nn.ReplicationPad2d(dilation)]
57
+ elif padding_type == 'zero':
58
+ p = dilation
59
+ else:
60
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
61
+
62
+ if in_dim is None:
63
+ in_dim = dim
64
+
65
+ conv_block += [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation),
66
+ norm_layer(dim),
67
+ activation]
68
+ if use_dropout:
69
+ conv_block += [nn.Dropout(0.5)]
70
+
71
+ p = 0
72
+ if padding_type == 'reflect':
73
+ conv_block += [nn.ReflectionPad2d(second_dilation)]
74
+ elif padding_type == 'replicate':
75
+ conv_block += [nn.ReplicationPad2d(second_dilation)]
76
+ elif padding_type == 'zero':
77
+ p = second_dilation
78
+ else:
79
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
80
+ conv_block += [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups),
81
+ norm_layer(dim)]
82
+
83
+ return nn.Sequential(*conv_block)
84
+
85
+ def forward(self, x):
86
+ x_before = x
87
+ if self.in_dim is not None:
88
+ x = self.input_conv(x)
89
+ out = x + self.conv_block(x_before)
90
+ return out
91
+
92
+ class ResnetBlock5x5(nn.Module):
93
+ def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
94
+ dilation=1, in_dim=None, groups=1, second_dilation=None):
95
+ super(ResnetBlock5x5, self).__init__()
96
+ self.in_dim = in_dim
97
+ self.dim = dim
98
+ if second_dilation is None:
99
+ second_dilation = dilation
100
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
101
+ conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
102
+ second_dilation=second_dilation)
103
+
104
+ if self.in_dim is not None:
105
+ self.input_conv = nn.Conv2d(in_dim, dim, 1)
106
+
107
+ self.out_channnels = dim
108
+
109
+ def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
110
+ dilation=1, in_dim=None, groups=1, second_dilation=1):
111
+ conv_layer = get_conv_block_ctor(conv_kind)
112
+
113
+ conv_block = []
114
+ p = 0
115
+ if padding_type == 'reflect':
116
+ conv_block += [nn.ReflectionPad2d(dilation * 2)]
117
+ elif padding_type == 'replicate':
118
+ conv_block += [nn.ReplicationPad2d(dilation * 2)]
119
+ elif padding_type == 'zero':
120
+ p = dilation * 2
121
+ else:
122
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
123
+
124
+ if in_dim is None:
125
+ in_dim = dim
126
+
127
+ conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation),
128
+ norm_layer(dim),
129
+ activation]
130
+ if use_dropout:
131
+ conv_block += [nn.Dropout(0.5)]
132
+
133
+ p = 0
134
+ if padding_type == 'reflect':
135
+ conv_block += [nn.ReflectionPad2d(second_dilation * 2)]
136
+ elif padding_type == 'replicate':
137
+ conv_block += [nn.ReplicationPad2d(second_dilation * 2)]
138
+ elif padding_type == 'zero':
139
+ p = second_dilation * 2
140
+ else:
141
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
142
+ conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups),
143
+ norm_layer(dim)]
144
+
145
+ return nn.Sequential(*conv_block)
146
+
147
+ def forward(self, x):
148
+ x_before = x
149
+ if self.in_dim is not None:
150
+ x = self.input_conv(x)
151
+ out = x + self.conv_block(x_before)
152
+ return out
153
+
154
+
155
+ class MultidilatedResnetBlock(nn.Module):
156
+ def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False):
157
+ super().__init__()
158
+ self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout)
159
+
160
+ def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1):
161
+ conv_block = []
162
+ conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
163
+ norm_layer(dim),
164
+ activation]
165
+ if use_dropout:
166
+ conv_block += [nn.Dropout(0.5)]
167
+
168
+ conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
169
+ norm_layer(dim)]
170
+
171
+ return nn.Sequential(*conv_block)
172
+
173
+ def forward(self, x):
174
+ out = x + self.conv_block(x)
175
+ return out
176
+
177
+
178
+ class MultiDilatedGlobalGenerator(nn.Module):
179
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
180
+ n_blocks=3, norm_layer=nn.BatchNorm2d,
181
+ padding_type='reflect', conv_kind='default',
182
+ deconv_kind='convtranspose', activation=nn.ReLU(True),
183
+ up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
184
+ add_out_act=True, max_features=1024, multidilation_kwargs={},
185
+ ffc_positions=None, ffc_kwargs={}):
186
+ assert (n_blocks >= 0)
187
+ super().__init__()
188
+
189
+ conv_layer = get_conv_block_ctor(conv_kind)
190
+ resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs)
191
+ norm_layer = get_norm_layer(norm_layer)
192
+ if affine is not None:
193
+ norm_layer = partial(norm_layer, affine=affine)
194
+ up_norm_layer = get_norm_layer(up_norm_layer)
195
+ if affine is not None:
196
+ up_norm_layer = partial(up_norm_layer, affine=affine)
197
+
198
+ model = [nn.ReflectionPad2d(3),
199
+ conv_layer(input_nc, ngf, kernel_size=7, padding=0),
200
+ norm_layer(ngf),
201
+ activation]
202
+
203
+ identity = Identity()
204
+ ### downsample
205
+ for i in range(n_downsampling):
206
+ mult = 2 ** i
207
+
208
+ model += [conv_layer(min(max_features, ngf * mult),
209
+ min(max_features, ngf * mult * 2),
210
+ kernel_size=3, stride=2, padding=1),
211
+ norm_layer(min(max_features, ngf * mult * 2)),
212
+ activation]
213
+
214
+ mult = 2 ** n_downsampling
215
+ feats_num_bottleneck = min(max_features, ngf * mult)
216
+
217
+ ### resnet blocks
218
+ for i in range(n_blocks):
219
+ if ffc_positions is not None and i in ffc_positions:
220
+ model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
221
+ inline=True, **ffc_kwargs)]
222
+ model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
223
+ conv_layer=resnet_conv_layer, activation=activation,
224
+ norm_layer=norm_layer)]
225
+
226
+ ### upsample
227
+ for i in range(n_downsampling):
228
+ mult = 2 ** (n_downsampling - i)
229
+ model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
230
+ model += [nn.ReflectionPad2d(3),
231
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
232
+ if add_out_act:
233
+ model.append(get_activation('tanh' if add_out_act is True else add_out_act))
234
+ self.model = nn.Sequential(*model)
235
+
236
+ def forward(self, input):
237
+ return self.model(input)
238
+
239
+ class ConfigGlobalGenerator(nn.Module):
240
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
241
+ n_blocks=3, norm_layer=nn.BatchNorm2d,
242
+ padding_type='reflect', conv_kind='default',
243
+ deconv_kind='convtranspose', activation=nn.ReLU(True),
244
+ up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
245
+ add_out_act=True, max_features=1024,
246
+ manual_block_spec=[],
247
+ resnet_block_kind='multidilatedresnetblock',
248
+ resnet_conv_kind='multidilated',
249
+ resnet_dilation=1,
250
+ multidilation_kwargs={}):
251
+ assert (n_blocks >= 0)
252
+ super().__init__()
253
+
254
+ conv_layer = get_conv_block_ctor(conv_kind)
255
+ resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs)
256
+ norm_layer = get_norm_layer(norm_layer)
257
+ if affine is not None:
258
+ norm_layer = partial(norm_layer, affine=affine)
259
+ up_norm_layer = get_norm_layer(up_norm_layer)
260
+ if affine is not None:
261
+ up_norm_layer = partial(up_norm_layer, affine=affine)
262
+
263
+ model = [nn.ReflectionPad2d(3),
264
+ conv_layer(input_nc, ngf, kernel_size=7, padding=0),
265
+ norm_layer(ngf),
266
+ activation]
267
+
268
+ identity = Identity()
269
+
270
+ ### downsample
271
+ for i in range(n_downsampling):
272
+ mult = 2 ** i
273
+ model += [conv_layer(min(max_features, ngf * mult),
274
+ min(max_features, ngf * mult * 2),
275
+ kernel_size=3, stride=2, padding=1),
276
+ norm_layer(min(max_features, ngf * mult * 2)),
277
+ activation]
278
+
279
+ mult = 2 ** n_downsampling
280
+ feats_num_bottleneck = min(max_features, ngf * mult)
281
+
282
+ if len(manual_block_spec) == 0:
283
+ manual_block_spec = [
284
+ DotDict(lambda : None, {
285
+ 'n_blocks': n_blocks,
286
+ 'use_default': True})
287
+ ]
288
+
289
+ ### resnet blocks
290
+ for block_spec in manual_block_spec:
291
+ def make_and_add_blocks(model, block_spec):
292
+ block_spec = DotDict(lambda : None, block_spec)
293
+ if not block_spec.use_default:
294
+ resnet_conv_layer = functools.partial(get_conv_block_ctor(block_spec.resnet_conv_kind), **block_spec.multidilation_kwargs)
295
+ resnet_conv_kind = block_spec.resnet_conv_kind
296
+ resnet_block_kind = block_spec.resnet_block_kind
297
+ if block_spec.resnet_dilation is not None:
298
+ resnet_dilation = block_spec.resnet_dilation
299
+ for i in range(block_spec.n_blocks):
300
+ if resnet_block_kind == "multidilatedresnetblock":
301
+ model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
302
+ conv_layer=resnet_conv_layer, activation=activation,
303
+ norm_layer=norm_layer)]
304
+ if resnet_block_kind == "resnetblock":
305
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
306
+ conv_kind=resnet_conv_kind)]
307
+ if resnet_block_kind == "resnetblock5x5":
308
+ model += [ResnetBlock5x5(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
309
+ conv_kind=resnet_conv_kind)]
310
+ if resnet_block_kind == "resnetblockdwdil":
311
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
312
+ conv_kind=resnet_conv_kind, dilation=resnet_dilation, second_dilation=resnet_dilation)]
313
+ make_and_add_blocks(model, block_spec)
314
+
315
+ ### upsample
316
+ for i in range(n_downsampling):
317
+ mult = 2 ** (n_downsampling - i)
318
+ model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
319
+ model += [nn.ReflectionPad2d(3),
320
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
321
+ if add_out_act:
322
+ model.append(get_activation('tanh' if add_out_act is True else add_out_act))
323
+ self.model = nn.Sequential(*model)
324
+
325
+ def forward(self, input):
326
+ return self.model(input)
327
+
328
+
329
+ def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs):
330
+ blocks = []
331
+ for i in range(dilated_blocks_n):
332
+ if dilation_block_kind == 'simple':
333
+ blocks.append(ResnetBlock(**dilated_block_kwargs, dilation=2 ** (i + 1)))
334
+ elif dilation_block_kind == 'multi':
335
+ blocks.append(MultidilatedResnetBlock(**dilated_block_kwargs))
336
+ else:
337
+ raise ValueError(f'dilation_block_kind could not be "{dilation_block_kind}"')
338
+ return blocks
339
+
340
+
341
+ class GlobalGenerator(nn.Module):
342
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
343
+ padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
344
+ up_norm_layer=nn.BatchNorm2d, affine=None,
345
+ up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0,
346
+ dilated_blocks_n_middle=0,
347
+ add_out_act=True,
348
+ max_features=1024, is_resblock_depthwise=False,
349
+ ffc_positions=None, ffc_kwargs={}, dilation=1, second_dilation=None,
350
+ dilation_block_kind='simple', multidilation_kwargs={}):
351
+ assert (n_blocks >= 0)
352
+ super().__init__()
353
+
354
+ conv_layer = get_conv_block_ctor(conv_kind)
355
+ norm_layer = get_norm_layer(norm_layer)
356
+ if affine is not None:
357
+ norm_layer = partial(norm_layer, affine=affine)
358
+ up_norm_layer = get_norm_layer(up_norm_layer)
359
+ if affine is not None:
360
+ up_norm_layer = partial(up_norm_layer, affine=affine)
361
+
362
+ if ffc_positions is not None:
363
+ ffc_positions = collections.Counter(ffc_positions)
364
+
365
+ model = [nn.ReflectionPad2d(3),
366
+ conv_layer(input_nc, ngf, kernel_size=7, padding=0),
367
+ norm_layer(ngf),
368
+ activation]
369
+
370
+ identity = Identity()
371
+ ### downsample
372
+ for i in range(n_downsampling):
373
+ mult = 2 ** i
374
+
375
+ model += [conv_layer(min(max_features, ngf * mult),
376
+ min(max_features, ngf * mult * 2),
377
+ kernel_size=3, stride=2, padding=1),
378
+ norm_layer(min(max_features, ngf * mult * 2)),
379
+ activation]
380
+
381
+ mult = 2 ** n_downsampling
382
+ feats_num_bottleneck = min(max_features, ngf * mult)
383
+
384
+ dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type,
385
+ activation=activation, norm_layer=norm_layer)
386
+ if dilation_block_kind == 'simple':
387
+ dilated_block_kwargs['conv_kind'] = conv_kind
388
+ elif dilation_block_kind == 'multi':
389
+ dilated_block_kwargs['conv_layer'] = functools.partial(
390
+ get_conv_block_ctor('multidilated'), **multidilation_kwargs)
391
+
392
+ # dilated blocks at the start of the bottleneck sausage
393
+ if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0:
394
+ model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs)
395
+
396
+ # resnet blocks
397
+ for i in range(n_blocks):
398
+ # dilated blocks at the middle of the bottleneck sausage
399
+ if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0:
400
+ model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs)
401
+
402
+ if ffc_positions is not None and i in ffc_positions:
403
+ for _ in range(ffc_positions[i]): # same position can occur more than once
404
+ model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
405
+ inline=True, **ffc_kwargs)]
406
+
407
+ if is_resblock_depthwise:
408
+ resblock_groups = feats_num_bottleneck
409
+ else:
410
+ resblock_groups = 1
411
+
412
+ model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
413
+ norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups,
414
+ dilation=dilation, second_dilation=second_dilation)]
415
+
416
+
417
+ # dilated blocks at the end of the bottleneck sausage
418
+ if dilated_blocks_n is not None and dilated_blocks_n > 0:
419
+ model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs)
420
+
421
+ # upsample
422
+ for i in range(n_downsampling):
423
+ mult = 2 ** (n_downsampling - i)
424
+ model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
425
+ min(max_features, int(ngf * mult / 2)),
426
+ kernel_size=3, stride=2, padding=1, output_padding=1),
427
+ up_norm_layer(min(max_features, int(ngf * mult / 2))),
428
+ up_activation]
429
+ model += [nn.ReflectionPad2d(3),
430
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
431
+ if add_out_act:
432
+ model.append(get_activation('tanh' if add_out_act is True else add_out_act))
433
+ self.model = nn.Sequential(*model)
434
+
435
+ def forward(self, input):
436
+ return self.model(input)
437
+
438
+
439
+ class GlobalGeneratorGated(GlobalGenerator):
440
+ def __init__(self, *args, **kwargs):
441
+ real_kwargs=dict(
442
+ conv_kind='gated_bn_relu',
443
+ activation=nn.Identity(),
444
+ norm_layer=nn.Identity
445
+ )
446
+ real_kwargs.update(kwargs)
447
+ super().__init__(*args, **real_kwargs)
448
+
449
+
450
+ class GlobalGeneratorFromSuperChannels(nn.Module):
451
+ def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True):
452
+ super().__init__()
453
+ self.n_downsampling = n_downsampling
454
+ norm_layer = get_norm_layer(norm_layer)
455
+ if type(norm_layer) == functools.partial:
456
+ use_bias = (norm_layer.func == nn.InstanceNorm2d)
457
+ else:
458
+ use_bias = (norm_layer == nn.InstanceNorm2d)
459
+
460
+ channels = self.convert_super_channels(super_channels)
461
+ self.channels = channels
462
+
463
+ model = [nn.ReflectionPad2d(3),
464
+ nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias),
465
+ norm_layer(channels[0]),
466
+ nn.ReLU(True)]
467
+
468
+ for i in range(n_downsampling): # add downsampling layers
469
+ mult = 2 ** i
470
+ model += [nn.Conv2d(channels[0+i], channels[1+i], kernel_size=3, stride=2, padding=1, bias=use_bias),
471
+ norm_layer(channels[1+i]),
472
+ nn.ReLU(True)]
473
+
474
+ mult = 2 ** n_downsampling
475
+
476
+ n_blocks1 = n_blocks // 3
477
+ n_blocks2 = n_blocks1
478
+ n_blocks3 = n_blocks - n_blocks1 - n_blocks2
479
+
480
+ for i in range(n_blocks1):
481
+ c = n_downsampling
482
+ dim = channels[c]
483
+ model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer)]
484
+
485
+ for i in range(n_blocks2):
486
+ c = n_downsampling+1
487
+ dim = channels[c]
488
+ kwargs = {}
489
+ if i == 0:
490
+ kwargs = {"in_dim": channels[c-1]}
491
+ model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
492
+
493
+ for i in range(n_blocks3):
494
+ c = n_downsampling+2
495
+ dim = channels[c]
496
+ kwargs = {}
497
+ if i == 0:
498
+ kwargs = {"in_dim": channels[c-1]}
499
+ model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
500
+
501
+ for i in range(n_downsampling): # add upsampling layers
502
+ mult = 2 ** (n_downsampling - i)
503
+ model += [nn.ConvTranspose2d(channels[n_downsampling+3+i],
504
+ channels[n_downsampling+3+i+1],
505
+ kernel_size=3, stride=2,
506
+ padding=1, output_padding=1,
507
+ bias=use_bias),
508
+ norm_layer(channels[n_downsampling+3+i+1]),
509
+ nn.ReLU(True)]
510
+ model += [nn.ReflectionPad2d(3)]
511
+ model += [nn.Conv2d(channels[2*n_downsampling+3], output_nc, kernel_size=7, padding=0)]
512
+
513
+ if add_out_act:
514
+ model.append(get_activation('tanh' if add_out_act is True else add_out_act))
515
+ self.model = nn.Sequential(*model)
516
+
517
+ def convert_super_channels(self, super_channels):
518
+ n_downsampling = self.n_downsampling
519
+ result = []
520
+ cnt = 0
521
+
522
+ if n_downsampling == 2:
523
+ N1 = 10
524
+ elif n_downsampling == 3:
525
+ N1 = 13
526
+ else:
527
+ raise NotImplementedError
528
+
529
+ for i in range(0, N1):
530
+ if i in [1,4,7,10]:
531
+ channel = super_channels[cnt] * (2 ** cnt)
532
+ config = {'channel': channel}
533
+ result.append(channel)
534
+ logging.info(f"Downsample channels {result[-1]}")
535
+ cnt += 1
536
+
537
+ for i in range(3):
538
+ for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)):
539
+ if len(super_channels) == 6:
540
+ channel = super_channels[3] * 4
541
+ else:
542
+ channel = super_channels[i + 3] * 4
543
+ config = {'channel': channel}
544
+ if counter == 0:
545
+ result.append(channel)
546
+ logging.info(f"Bottleneck channels {result[-1]}")
547
+ cnt = 2
548
+
549
+ for i in range(N1+9, N1+21):
550
+ if i in [22, 25,28]:
551
+ cnt -= 1
552
+ if len(super_channels) == 6:
553
+ channel = super_channels[5 - cnt] * (2 ** cnt)
554
+ else:
555
+ channel = super_channels[7 - cnt] * (2 ** cnt)
556
+ result.append(int(channel))
557
+ logging.info(f"Upsample channels {result[-1]}")
558
+ return result
559
+
560
+ def forward(self, input):
561
+ return self.model(input)
562
+
563
+
564
+ # Defines the PatchGAN discriminator with the specified arguments.
565
+ class NLayerDiscriminator(BaseDiscriminator):
566
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,):
567
+ super().__init__()
568
+ self.n_layers = n_layers
569
+
570
+ kw = 4
571
+ padw = int(np.ceil((kw-1.0)/2))
572
+ sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
573
+ nn.LeakyReLU(0.2, True)]]
574
+
575
+ nf = ndf
576
+ for n in range(1, n_layers):
577
+ nf_prev = nf
578
+ nf = min(nf * 2, 512)
579
+
580
+ cur_model = []
581
+ cur_model += [
582
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
583
+ norm_layer(nf),
584
+ nn.LeakyReLU(0.2, True)
585
+ ]
586
+ sequence.append(cur_model)
587
+
588
+ nf_prev = nf
589
+ nf = min(nf * 2, 512)
590
+
591
+ cur_model = []
592
+ cur_model += [
593
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
594
+ norm_layer(nf),
595
+ nn.LeakyReLU(0.2, True)
596
+ ]
597
+ sequence.append(cur_model)
598
+
599
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
600
+
601
+ for n in range(len(sequence)):
602
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
603
+
604
+ def get_all_activations(self, x):
605
+ res = [x]
606
+ for n in range(self.n_layers + 2):
607
+ model = getattr(self, 'model' + str(n))
608
+ res.append(model(res[-1]))
609
+ return res[1:]
610
+
611
+ def forward(self, x):
612
+ act = self.get_all_activations(x)
613
+ return act[-1], act[:-1]
614
+
615
+
616
+ class MultidilatedNLayerDiscriminator(BaseDiscriminator):
617
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}):
618
+ super().__init__()
619
+ self.n_layers = n_layers
620
+
621
+ kw = 4
622
+ padw = int(np.ceil((kw-1.0)/2))
623
+ sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
624
+ nn.LeakyReLU(0.2, True)]]
625
+
626
+ nf = ndf
627
+ for n in range(1, n_layers):
628
+ nf_prev = nf
629
+ nf = min(nf * 2, 512)
630
+
631
+ cur_model = []
632
+ cur_model += [
633
+ MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs),
634
+ norm_layer(nf),
635
+ nn.LeakyReLU(0.2, True)
636
+ ]
637
+ sequence.append(cur_model)
638
+
639
+ nf_prev = nf
640
+ nf = min(nf * 2, 512)
641
+
642
+ cur_model = []
643
+ cur_model += [
644
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
645
+ norm_layer(nf),
646
+ nn.LeakyReLU(0.2, True)
647
+ ]
648
+ sequence.append(cur_model)
649
+
650
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
651
+
652
+ for n in range(len(sequence)):
653
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
654
+
655
+ def get_all_activations(self, x):
656
+ res = [x]
657
+ for n in range(self.n_layers + 2):
658
+ model = getattr(self, 'model' + str(n))
659
+ res.append(model(res[-1]))
660
+ return res[1:]
661
+
662
+ def forward(self, x):
663
+ act = self.get_all_activations(x)
664
+ return act[-1], act[:-1]
665
+
666
+
667
+ class NLayerDiscriminatorAsGen(NLayerDiscriminator):
668
+ def forward(self, x):
669
+ return super().forward(x)[0]
annotator/lama/saicinpainting/training/modules/spatial_transform.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from kornia.geometry.transform import rotate
5
+
6
+
7
+ class LearnableSpatialTransformWrapper(nn.Module):
8
+ def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
9
+ super().__init__()
10
+ self.impl = impl
11
+ self.angle = torch.rand(1) * angle_init_range
12
+ if train_angle:
13
+ self.angle = nn.Parameter(self.angle, requires_grad=True)
14
+ self.pad_coef = pad_coef
15
+
16
+ def forward(self, x):
17
+ if torch.is_tensor(x):
18
+ return self.inverse_transform(self.impl(self.transform(x)), x)
19
+ elif isinstance(x, tuple):
20
+ x_trans = tuple(self.transform(elem) for elem in x)
21
+ y_trans = self.impl(x_trans)
22
+ return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x))
23
+ else:
24
+ raise ValueError(f'Unexpected input type {type(x)}')
25
+
26
+ def transform(self, x):
27
+ height, width = x.shape[2:]
28
+ pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
29
+ x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect')
30
+ x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded))
31
+ return x_padded_rotated
32
+
33
+ def inverse_transform(self, y_padded_rotated, orig_x):
34
+ height, width = orig_x.shape[2:]
35
+ pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
36
+
37
+ y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated))
38
+ y_height, y_width = y_padded.shape[2:]
39
+ y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
40
+ return y
41
+
42
+
43
+ if __name__ == '__main__':
44
+ layer = LearnableSpatialTransformWrapper(nn.Identity())
45
+ x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float()
46
+ y = layer(x)
47
+ assert x.shape == y.shape
48
+ assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1])
49
+ print('all ok')
annotator/lama/saicinpainting/training/modules/squeeze_excitation.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
+ class SELayer(nn.Module):
5
+ def __init__(self, channel, reduction=16):
6
+ super(SELayer, self).__init__()
7
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
8
+ self.fc = nn.Sequential(
9
+ nn.Linear(channel, channel // reduction, bias=False),
10
+ nn.ReLU(inplace=True),
11
+ nn.Linear(channel // reduction, channel, bias=False),
12
+ nn.Sigmoid()
13
+ )
14
+
15
+ def forward(self, x):
16
+ b, c, _, _ = x.size()
17
+ y = self.avg_pool(x).view(b, c)
18
+ y = self.fc(y).view(b, c, 1, 1)
19
+ res = x * y.expand_as(x)
20
+ return res
annotator/lama/saicinpainting/training/trainers/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import torch
3
+ from annotator.lama.saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule
4
+
5
+
6
+ def get_training_model_class(kind):
7
+ if kind == 'default':
8
+ return DefaultInpaintingTrainingModule
9
+
10
+ raise ValueError(f'Unknown trainer module {kind}')
11
+
12
+
13
+ def make_training_model(config):
14
+ kind = config.training_model.kind
15
+ kwargs = dict(config.training_model)
16
+ kwargs.pop('kind')
17
+ kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp'
18
+
19
+ logging.info(f'Make training model {kind}')
20
+
21
+ cls = get_training_model_class(kind)
22
+ return cls(config, **kwargs)
23
+
24
+
25
+ def load_checkpoint(train_config, path, map_location='cuda', strict=True):
26
+ model = make_training_model(train_config).generator
27
+ state = torch.load(path, map_location=map_location)
28
+ model.load_state_dict(state, strict=strict)
29
+ return model
annotator/lama/saicinpainting/training/trainers/base.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ from typing import Dict, Tuple
4
+
5
+ import pandas as pd
6
+ import pytorch_lightning as ptl
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ # from torch.utils.data import DistributedSampler
11
+
12
+ # from annotator.lama.saicinpainting.evaluation import make_evaluator
13
+ # from annotator.lama.saicinpainting.training.data.datasets import make_default_train_dataloader, make_default_val_dataloader
14
+ # from annotator.lama.saicinpainting.training.losses.adversarial import make_discrim_loss
15
+ # from annotator.lama.saicinpainting.training.losses.perceptual import PerceptualLoss, ResNetPL
16
+ from annotator.lama.saicinpainting.training.modules import make_generator #, make_discriminator
17
+ # from annotator.lama.saicinpainting.training.visualizers import make_visualizer
18
+ from annotator.lama.saicinpainting.utils import add_prefix_to_keys, average_dicts, set_requires_grad, flatten_dict, \
19
+ get_has_ddp_rank
20
+
21
+ LOGGER = logging.getLogger(__name__)
22
+
23
+
24
+ def make_optimizer(parameters, kind='adamw', **kwargs):
25
+ if kind == 'adam':
26
+ optimizer_class = torch.optim.Adam
27
+ elif kind == 'adamw':
28
+ optimizer_class = torch.optim.AdamW
29
+ else:
30
+ raise ValueError(f'Unknown optimizer kind {kind}')
31
+ return optimizer_class(parameters, **kwargs)
32
+
33
+
34
+ def update_running_average(result: nn.Module, new_iterate_model: nn.Module, decay=0.999):
35
+ with torch.no_grad():
36
+ res_params = dict(result.named_parameters())
37
+ new_params = dict(new_iterate_model.named_parameters())
38
+
39
+ for k in res_params.keys():
40
+ res_params[k].data.mul_(decay).add_(new_params[k].data, alpha=1 - decay)
41
+
42
+
43
+ def make_multiscale_noise(base_tensor, scales=6, scale_mode='bilinear'):
44
+ batch_size, _, height, width = base_tensor.shape
45
+ cur_height, cur_width = height, width
46
+ result = []
47
+ align_corners = False if scale_mode in ('bilinear', 'bicubic') else None
48
+ for _ in range(scales):
49
+ cur_sample = torch.randn(batch_size, 1, cur_height, cur_width, device=base_tensor.device)
50
+ cur_sample_scaled = F.interpolate(cur_sample, size=(height, width), mode=scale_mode, align_corners=align_corners)
51
+ result.append(cur_sample_scaled)
52
+ cur_height //= 2
53
+ cur_width //= 2
54
+ return torch.cat(result, dim=1)
55
+
56
+
57
+ class BaseInpaintingTrainingModule(ptl.LightningModule):
58
+ def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_iters=100,
59
+ average_generator=False, generator_avg_beta=0.999, average_generator_start_step=30000,
60
+ average_generator_period=10, store_discr_outputs_for_vis=False,
61
+ **kwargs):
62
+ super().__init__(*args, **kwargs)
63
+ LOGGER.info('BaseInpaintingTrainingModule init called')
64
+
65
+ self.config = config
66
+
67
+ self.generator = make_generator(config, **self.config.generator)
68
+ self.use_ddp = use_ddp
69
+
70
+ if not get_has_ddp_rank():
71
+ LOGGER.info(f'Generator\n{self.generator}')
72
+
73
+ # if not predict_only:
74
+ # self.save_hyperparameters(self.config)
75
+ # self.discriminator = make_discriminator(**self.config.discriminator)
76
+ # self.adversarial_loss = make_discrim_loss(**self.config.losses.adversarial)
77
+ # self.visualizer = make_visualizer(**self.config.visualizer)
78
+ # self.val_evaluator = make_evaluator(**self.config.evaluator)
79
+ # self.test_evaluator = make_evaluator(**self.config.evaluator)
80
+ #
81
+ # if not get_has_ddp_rank():
82
+ # LOGGER.info(f'Discriminator\n{self.discriminator}')
83
+ #
84
+ # extra_val = self.config.data.get('extra_val', ())
85
+ # if extra_val:
86
+ # self.extra_val_titles = list(extra_val)
87
+ # self.extra_evaluators = nn.ModuleDict({k: make_evaluator(**self.config.evaluator)
88
+ # for k in extra_val})
89
+ # else:
90
+ # self.extra_evaluators = {}
91
+ #
92
+ # self.average_generator = average_generator
93
+ # self.generator_avg_beta = generator_avg_beta
94
+ # self.average_generator_start_step = average_generator_start_step
95
+ # self.average_generator_period = average_generator_period
96
+ # self.generator_average = None
97
+ # self.last_generator_averaging_step = -1
98
+ # self.store_discr_outputs_for_vis = store_discr_outputs_for_vis
99
+ #
100
+ # if self.config.losses.get("l1", {"weight_known": 0})['weight_known'] > 0:
101
+ # self.loss_l1 = nn.L1Loss(reduction='none')
102
+ #
103
+ # if self.config.losses.get("mse", {"weight": 0})['weight'] > 0:
104
+ # self.loss_mse = nn.MSELoss(reduction='none')
105
+ #
106
+ # if self.config.losses.perceptual.weight > 0:
107
+ # self.loss_pl = PerceptualLoss()
108
+ #
109
+ # # if self.config.losses.get("resnet_pl", {"weight": 0})['weight'] > 0:
110
+ # # self.loss_resnet_pl = ResNetPL(**self.config.losses.resnet_pl)
111
+ # # else:
112
+ # # self.loss_resnet_pl = None
113
+ #
114
+ # self.loss_resnet_pl = None
115
+
116
+ self.visualize_each_iters = visualize_each_iters
117
+ LOGGER.info('BaseInpaintingTrainingModule init done')
118
+
119
+ def configure_optimizers(self):
120
+ discriminator_params = list(self.discriminator.parameters())
121
+ return [
122
+ dict(optimizer=make_optimizer(self.generator.parameters(), **self.config.optimizers.generator)),
123
+ dict(optimizer=make_optimizer(discriminator_params, **self.config.optimizers.discriminator)),
124
+ ]
125
+
126
+ def train_dataloader(self):
127
+ kwargs = dict(self.config.data.train)
128
+ if self.use_ddp:
129
+ kwargs['ddp_kwargs'] = dict(num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
130
+ rank=self.trainer.global_rank,
131
+ shuffle=True)
132
+ dataloader = make_default_train_dataloader(**self.config.data.train)
133
+ return dataloader
134
+
135
+ def val_dataloader(self):
136
+ res = [make_default_val_dataloader(**self.config.data.val)]
137
+
138
+ if self.config.data.visual_test is not None:
139
+ res = res + [make_default_val_dataloader(**self.config.data.visual_test)]
140
+ else:
141
+ res = res + res
142
+
143
+ extra_val = self.config.data.get('extra_val', ())
144
+ if extra_val:
145
+ res += [make_default_val_dataloader(**extra_val[k]) for k in self.extra_val_titles]
146
+
147
+ return res
148
+
149
+ def training_step(self, batch, batch_idx, optimizer_idx=None):
150
+ self._is_training_step = True
151
+ return self._do_step(batch, batch_idx, mode='train', optimizer_idx=optimizer_idx)
152
+
153
+ def validation_step(self, batch, batch_idx, dataloader_idx):
154
+ extra_val_key = None
155
+ if dataloader_idx == 0:
156
+ mode = 'val'
157
+ elif dataloader_idx == 1:
158
+ mode = 'test'
159
+ else:
160
+ mode = 'extra_val'
161
+ extra_val_key = self.extra_val_titles[dataloader_idx - 2]
162
+ self._is_training_step = False
163
+ return self._do_step(batch, batch_idx, mode=mode, extra_val_key=extra_val_key)
164
+
165
+ def training_step_end(self, batch_parts_outputs):
166
+ if self.training and self.average_generator \
167
+ and self.global_step >= self.average_generator_start_step \
168
+ and self.global_step >= self.last_generator_averaging_step + self.average_generator_period:
169
+ if self.generator_average is None:
170
+ self.generator_average = copy.deepcopy(self.generator)
171
+ else:
172
+ update_running_average(self.generator_average, self.generator, decay=self.generator_avg_beta)
173
+ self.last_generator_averaging_step = self.global_step
174
+
175
+ full_loss = (batch_parts_outputs['loss'].mean()
176
+ if torch.is_tensor(batch_parts_outputs['loss']) # loss is not tensor when no discriminator used
177
+ else torch.tensor(batch_parts_outputs['loss']).float().requires_grad_(True))
178
+ log_info = {k: v.mean() for k, v in batch_parts_outputs['log_info'].items()}
179
+ self.log_dict(log_info, on_step=True, on_epoch=False)
180
+ return full_loss
181
+
182
+ def validation_epoch_end(self, outputs):
183
+ outputs = [step_out for out_group in outputs for step_out in out_group]
184
+ averaged_logs = average_dicts(step_out['log_info'] for step_out in outputs)
185
+ self.log_dict({k: v.mean() for k, v in averaged_logs.items()})
186
+
187
+ pd.set_option('display.max_columns', 500)
188
+ pd.set_option('display.width', 1000)
189
+
190
+ # standard validation
191
+ val_evaluator_states = [s['val_evaluator_state'] for s in outputs if 'val_evaluator_state' in s]
192
+ val_evaluator_res = self.val_evaluator.evaluation_end(states=val_evaluator_states)
193
+ val_evaluator_res_df = pd.DataFrame(val_evaluator_res).stack(1).unstack(0)
194
+ val_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
195
+ LOGGER.info(f'Validation metrics after epoch #{self.current_epoch}, '
196
+ f'total {self.global_step} iterations:\n{val_evaluator_res_df}')
197
+
198
+ for k, v in flatten_dict(val_evaluator_res).items():
199
+ self.log(f'val_{k}', v)
200
+
201
+ # standard visual test
202
+ test_evaluator_states = [s['test_evaluator_state'] for s in outputs
203
+ if 'test_evaluator_state' in s]
204
+ test_evaluator_res = self.test_evaluator.evaluation_end(states=test_evaluator_states)
205
+ test_evaluator_res_df = pd.DataFrame(test_evaluator_res).stack(1).unstack(0)
206
+ test_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
207
+ LOGGER.info(f'Test metrics after epoch #{self.current_epoch}, '
208
+ f'total {self.global_step} iterations:\n{test_evaluator_res_df}')
209
+
210
+ for k, v in flatten_dict(test_evaluator_res).items():
211
+ self.log(f'test_{k}', v)
212
+
213
+ # extra validations
214
+ if self.extra_evaluators:
215
+ for cur_eval_title, cur_evaluator in self.extra_evaluators.items():
216
+ cur_state_key = f'extra_val_{cur_eval_title}_evaluator_state'
217
+ cur_states = [s[cur_state_key] for s in outputs if cur_state_key in s]
218
+ cur_evaluator_res = cur_evaluator.evaluation_end(states=cur_states)
219
+ cur_evaluator_res_df = pd.DataFrame(cur_evaluator_res).stack(1).unstack(0)
220
+ cur_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
221
+ LOGGER.info(f'Extra val {cur_eval_title} metrics after epoch #{self.current_epoch}, '
222
+ f'total {self.global_step} iterations:\n{cur_evaluator_res_df}')
223
+ for k, v in flatten_dict(cur_evaluator_res).items():
224
+ self.log(f'extra_val_{cur_eval_title}_{k}', v)
225
+
226
+ def _do_step(self, batch, batch_idx, mode='train', optimizer_idx=None, extra_val_key=None):
227
+ if optimizer_idx == 0: # step for generator
228
+ set_requires_grad(self.generator, True)
229
+ set_requires_grad(self.discriminator, False)
230
+ elif optimizer_idx == 1: # step for discriminator
231
+ set_requires_grad(self.generator, False)
232
+ set_requires_grad(self.discriminator, True)
233
+
234
+ batch = self(batch)
235
+
236
+ total_loss = 0
237
+ metrics = {}
238
+
239
+ if optimizer_idx is None or optimizer_idx == 0: # step for generator
240
+ total_loss, metrics = self.generator_loss(batch)
241
+
242
+ elif optimizer_idx is None or optimizer_idx == 1: # step for discriminator
243
+ if self.config.losses.adversarial.weight > 0:
244
+ total_loss, metrics = self.discriminator_loss(batch)
245
+
246
+ if self.get_ddp_rank() in (None, 0) and (batch_idx % self.visualize_each_iters == 0 or mode == 'test'):
247
+ if self.config.losses.adversarial.weight > 0:
248
+ if self.store_discr_outputs_for_vis:
249
+ with torch.no_grad():
250
+ self.store_discr_outputs(batch)
251
+ vis_suffix = f'_{mode}'
252
+ if mode == 'extra_val':
253
+ vis_suffix += f'_{extra_val_key}'
254
+ self.visualizer(self.current_epoch, batch_idx, batch, suffix=vis_suffix)
255
+
256
+ metrics_prefix = f'{mode}_'
257
+ if mode == 'extra_val':
258
+ metrics_prefix += f'{extra_val_key}_'
259
+ result = dict(loss=total_loss, log_info=add_prefix_to_keys(metrics, metrics_prefix))
260
+ if mode == 'val':
261
+ result['val_evaluator_state'] = self.val_evaluator.process_batch(batch)
262
+ elif mode == 'test':
263
+ result['test_evaluator_state'] = self.test_evaluator.process_batch(batch)
264
+ elif mode == 'extra_val':
265
+ result[f'extra_val_{extra_val_key}_evaluator_state'] = self.extra_evaluators[extra_val_key].process_batch(batch)
266
+
267
+ return result
268
+
269
+ def get_current_generator(self, no_average=False):
270
+ if not no_average and not self.training and self.average_generator and self.generator_average is not None:
271
+ return self.generator_average
272
+ return self.generator
273
+
274
+ def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
275
+ """Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys"""
276
+ raise NotImplementedError()
277
+
278
+ def generator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
279
+ raise NotImplementedError()
280
+
281
+ def discriminator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
282
+ raise NotImplementedError()
283
+
284
+ def store_discr_outputs(self, batch):
285
+ out_size = batch['image'].shape[2:]
286
+ discr_real_out, _ = self.discriminator(batch['image'])
287
+ discr_fake_out, _ = self.discriminator(batch['predicted_image'])
288
+ batch['discr_output_real'] = F.interpolate(discr_real_out, size=out_size, mode='nearest')
289
+ batch['discr_output_fake'] = F.interpolate(discr_fake_out, size=out_size, mode='nearest')
290
+ batch['discr_output_diff'] = batch['discr_output_real'] - batch['discr_output_fake']
291
+
292
+ def get_ddp_rank(self):
293
+ return self.trainer.global_rank if (self.trainer.num_nodes * self.trainer.num_processes) > 1 else None
annotator/lama/saicinpainting/training/trainers/default.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from omegaconf import OmegaConf
6
+
7
+ # from annotator.lama.saicinpainting.training.data.datasets import make_constant_area_crop_params
8
+ from annotator.lama.saicinpainting.training.losses.distance_weighting import make_mask_distance_weighter
9
+ from annotator.lama.saicinpainting.training.losses.feature_matching import feature_matching_loss, masked_l1_loss
10
+ # from annotator.lama.saicinpainting.training.modules.fake_fakes import FakeFakesGenerator
11
+ from annotator.lama.saicinpainting.training.trainers.base import BaseInpaintingTrainingModule, make_multiscale_noise
12
+ from annotator.lama.saicinpainting.utils import add_prefix_to_keys, get_ramp
13
+
14
+ LOGGER = logging.getLogger(__name__)
15
+
16
+
17
+ def make_constant_area_crop_batch(batch, **kwargs):
18
+ crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
19
+ img_width=batch['image'].shape[3],
20
+ **kwargs)
21
+ batch['image'] = batch['image'][:, :, crop_y : crop_y + crop_height, crop_x : crop_x + crop_width]
22
+ batch['mask'] = batch['mask'][:, :, crop_y: crop_y + crop_height, crop_x: crop_x + crop_width]
23
+ return batch
24
+
25
+
26
+ class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
27
+ def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
28
+ add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
29
+ distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
30
+ fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
31
+ **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+ self.concat_mask = concat_mask
34
+ self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None
35
+ self.image_to_discriminator = image_to_discriminator
36
+ self.add_noise_kwargs = add_noise_kwargs
37
+ self.noise_fill_hole = noise_fill_hole
38
+ self.const_area_crop_kwargs = const_area_crop_kwargs
39
+ self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \
40
+ if distance_weighter_kwargs is not None else None
41
+ self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
42
+
43
+ self.fake_fakes_proba = fake_fakes_proba
44
+ if self.fake_fakes_proba > 1e-3:
45
+ self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {}))
46
+
47
+ def forward(self, batch):
48
+ if self.training and self.rescale_size_getter is not None:
49
+ cur_size = self.rescale_size_getter(self.global_step)
50
+ batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
51
+ batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
52
+
53
+ if self.training and self.const_area_crop_kwargs is not None:
54
+ batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
55
+
56
+ img = batch['image']
57
+ mask = batch['mask']
58
+
59
+ masked_img = img * (1 - mask)
60
+
61
+ if self.add_noise_kwargs is not None:
62
+ noise = make_multiscale_noise(masked_img, **self.add_noise_kwargs)
63
+ if self.noise_fill_hole:
64
+ masked_img = masked_img + mask * noise[:, :masked_img.shape[1]]
65
+ masked_img = torch.cat([masked_img, noise], dim=1)
66
+
67
+ if self.concat_mask:
68
+ masked_img = torch.cat([masked_img, mask], dim=1)
69
+
70
+ batch['predicted_image'] = self.generator(masked_img)
71
+ batch['inpainted'] = mask * batch['predicted_image'] + (1 - mask) * batch['image']
72
+
73
+ if self.fake_fakes_proba > 1e-3:
74
+ if self.training and torch.rand(1).item() < self.fake_fakes_proba:
75
+ batch['fake_fakes'], batch['fake_fakes_masks'] = self.fake_fakes_gen(img, mask)
76
+ batch['use_fake_fakes'] = True
77
+ else:
78
+ batch['fake_fakes'] = torch.zeros_like(img)
79
+ batch['fake_fakes_masks'] = torch.zeros_like(mask)
80
+ batch['use_fake_fakes'] = False
81
+
82
+ batch['mask_for_losses'] = self.refine_mask_for_losses(img, batch['predicted_image'], mask) \
83
+ if self.refine_mask_for_losses is not None and self.training \
84
+ else mask
85
+
86
+ return batch
87
+
88
+ def generator_loss(self, batch):
89
+ img = batch['image']
90
+ predicted_img = batch[self.image_to_discriminator]
91
+ original_mask = batch['mask']
92
+ supervised_mask = batch['mask_for_losses']
93
+
94
+ # L1
95
+ l1_value = masked_l1_loss(predicted_img, img, supervised_mask,
96
+ self.config.losses.l1.weight_known,
97
+ self.config.losses.l1.weight_missing)
98
+
99
+ total_loss = l1_value
100
+ metrics = dict(gen_l1=l1_value)
101
+
102
+ # vgg-based perceptual loss
103
+ if self.config.losses.perceptual.weight > 0:
104
+ pl_value = self.loss_pl(predicted_img, img, mask=supervised_mask).sum() * self.config.losses.perceptual.weight
105
+ total_loss = total_loss + pl_value
106
+ metrics['gen_pl'] = pl_value
107
+
108
+ # discriminator
109
+ # adversarial_loss calls backward by itself
110
+ mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask
111
+ self.adversarial_loss.pre_generator_step(real_batch=img, fake_batch=predicted_img,
112
+ generator=self.generator, discriminator=self.discriminator)
113
+ discr_real_pred, discr_real_features = self.discriminator(img)
114
+ discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
115
+ adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(real_batch=img,
116
+ fake_batch=predicted_img,
117
+ discr_real_pred=discr_real_pred,
118
+ discr_fake_pred=discr_fake_pred,
119
+ mask=mask_for_discr)
120
+ total_loss = total_loss + adv_gen_loss
121
+ metrics['gen_adv'] = adv_gen_loss
122
+ metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
123
+
124
+ # feature matching
125
+ if self.config.losses.feature_matching.weight > 0:
126
+ need_mask_in_fm = OmegaConf.to_container(self.config.losses.feature_matching).get('pass_mask', False)
127
+ mask_for_fm = supervised_mask if need_mask_in_fm else None
128
+ fm_value = feature_matching_loss(discr_fake_features, discr_real_features,
129
+ mask=mask_for_fm) * self.config.losses.feature_matching.weight
130
+ total_loss = total_loss + fm_value
131
+ metrics['gen_fm'] = fm_value
132
+
133
+ if self.loss_resnet_pl is not None:
134
+ resnet_pl_value = self.loss_resnet_pl(predicted_img, img)
135
+ total_loss = total_loss + resnet_pl_value
136
+ metrics['gen_resnet_pl'] = resnet_pl_value
137
+
138
+ return total_loss, metrics
139
+
140
+ def discriminator_loss(self, batch):
141
+ total_loss = 0
142
+ metrics = {}
143
+
144
+ predicted_img = batch[self.image_to_discriminator].detach()
145
+ self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=predicted_img,
146
+ generator=self.generator, discriminator=self.discriminator)
147
+ discr_real_pred, discr_real_features = self.discriminator(batch['image'])
148
+ discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
149
+ adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(real_batch=batch['image'],
150
+ fake_batch=predicted_img,
151
+ discr_real_pred=discr_real_pred,
152
+ discr_fake_pred=discr_fake_pred,
153
+ mask=batch['mask'])
154
+ total_loss = total_loss + adv_discr_loss
155
+ metrics['discr_adv'] = adv_discr_loss
156
+ metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
157
+
158
+
159
+ if batch.get('use_fake_fakes', False):
160
+ fake_fakes = batch['fake_fakes']
161
+ self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=fake_fakes,
162
+ generator=self.generator, discriminator=self.discriminator)
163
+ discr_fake_fakes_pred, _ = self.discriminator(fake_fakes)
164
+ fake_fakes_adv_discr_loss, fake_fakes_adv_metrics = self.adversarial_loss.discriminator_loss(
165
+ real_batch=batch['image'],
166
+ fake_batch=fake_fakes,
167
+ discr_real_pred=discr_real_pred,
168
+ discr_fake_pred=discr_fake_fakes_pred,
169
+ mask=batch['mask']
170
+ )
171
+ total_loss = total_loss + fake_fakes_adv_discr_loss
172
+ metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss
173
+ metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_'))
174
+
175
+ return total_loss, metrics
annotator/lama/saicinpainting/training/visualizers/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from annotator.lama.saicinpainting.training.visualizers.directory import DirectoryVisualizer
4
+ from annotator.lama.saicinpainting.training.visualizers.noop import NoopVisualizer
5
+
6
+
7
+ def make_visualizer(kind, **kwargs):
8
+ logging.info(f'Make visualizer {kind}')
9
+
10
+ if kind == 'directory':
11
+ return DirectoryVisualizer(**kwargs)
12
+ if kind == 'noop':
13
+ return NoopVisualizer()
14
+
15
+ raise ValueError(f'Unknown visualizer kind {kind}')
annotator/lama/saicinpainting/training/visualizers/base.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Dict, List
3
+
4
+ import numpy as np
5
+ import torch
6
+ from skimage import color
7
+ from skimage.segmentation import mark_boundaries
8
+
9
+ from . import colors
10
+
11
+ COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation
12
+
13
+
14
+ class BaseVisualizer:
15
+ @abc.abstractmethod
16
+ def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
17
+ """
18
+ Take a batch, make an image from it and visualize
19
+ """
20
+ raise NotImplementedError()
21
+
22
+
23
+ def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str],
24
+ last_without_mask=True, rescale_keys=None, mask_only_first=None,
25
+ black_mask=False) -> np.ndarray:
26
+ mask = images_dict['mask'] > 0.5
27
+ result = []
28
+ for i, k in enumerate(keys):
29
+ img = images_dict[k]
30
+ img = np.transpose(img, (1, 2, 0))
31
+
32
+ if rescale_keys is not None and k in rescale_keys:
33
+ img = img - img.min()
34
+ img /= img.max() + 1e-5
35
+ if len(img.shape) == 2:
36
+ img = np.expand_dims(img, 2)
37
+
38
+ if img.shape[2] == 1:
39
+ img = np.repeat(img, 3, axis=2)
40
+ elif (img.shape[2] > 3):
41
+ img_classes = img.argmax(2)
42
+ img = color.label2rgb(img_classes, colors=COLORS)
43
+
44
+ if mask_only_first:
45
+ need_mark_boundaries = i == 0
46
+ else:
47
+ need_mark_boundaries = i < len(keys) - 1 or not last_without_mask
48
+
49
+ if need_mark_boundaries:
50
+ if black_mask:
51
+ img = img * (1 - mask[0][..., None])
52
+ img = mark_boundaries(img,
53
+ mask[0],
54
+ color=(1., 0., 0.),
55
+ outline_color=(1., 1., 1.),
56
+ mode='thick')
57
+ result.append(img)
58
+ return np.concatenate(result, axis=1)
59
+
60
+
61
+ def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10,
62
+ last_without_mask=True, rescale_keys=None) -> np.ndarray:
63
+ batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items()
64
+ if k in keys or k == 'mask'}
65
+
66
+ batch_size = next(iter(batch.values())).shape[0]
67
+ items_to_vis = min(batch_size, max_items)
68
+ result = []
69
+ for i in range(items_to_vis):
70
+ cur_dct = {k: tens[i] for k, tens in batch.items()}
71
+ result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask,
72
+ rescale_keys=rescale_keys))
73
+ return np.concatenate(result, axis=0)
annotator/lama/saicinpainting/training/visualizers/colors.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import colorsys
3
+
4
+ import numpy as np
5
+ import matplotlib
6
+ matplotlib.use('agg')
7
+ import matplotlib.pyplot as plt
8
+ from matplotlib.colors import LinearSegmentedColormap
9
+
10
+
11
+ def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False):
12
+ # https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib
13
+ """
14
+ Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks
15
+ :param nlabels: Number of labels (size of colormap)
16
+ :param type: 'bright' for strong colors, 'soft' for pastel colors
17
+ :param first_color_black: Option to use first color as black, True or False
18
+ :param last_color_black: Option to use last color as black, True or False
19
+ :param verbose: Prints the number of labels and shows the colormap. True or False
20
+ :return: colormap for matplotlib
21
+ """
22
+ if type not in ('bright', 'soft'):
23
+ print ('Please choose "bright" or "soft" for type')
24
+ return
25
+
26
+ if verbose:
27
+ print('Number of labels: ' + str(nlabels))
28
+
29
+ # Generate color map for bright colors, based on hsv
30
+ if type == 'bright':
31
+ randHSVcolors = [(np.random.uniform(low=0.0, high=1),
32
+ np.random.uniform(low=0.2, high=1),
33
+ np.random.uniform(low=0.9, high=1)) for i in range(nlabels)]
34
+
35
+ # Convert HSV list to RGB
36
+ randRGBcolors = []
37
+ for HSVcolor in randHSVcolors:
38
+ randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]))
39
+
40
+ if first_color_black:
41
+ randRGBcolors[0] = [0, 0, 0]
42
+
43
+ if last_color_black:
44
+ randRGBcolors[-1] = [0, 0, 0]
45
+
46
+ random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
47
+
48
+ # Generate soft pastel colors, by limiting the RGB spectrum
49
+ if type == 'soft':
50
+ low = 0.6
51
+ high = 0.95
52
+ randRGBcolors = [(np.random.uniform(low=low, high=high),
53
+ np.random.uniform(low=low, high=high),
54
+ np.random.uniform(low=low, high=high)) for i in range(nlabels)]
55
+
56
+ if first_color_black:
57
+ randRGBcolors[0] = [0, 0, 0]
58
+
59
+ if last_color_black:
60
+ randRGBcolors[-1] = [0, 0, 0]
61
+ random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
62
+
63
+ # Display colorbar
64
+ if verbose:
65
+ from matplotlib import colors, colorbar
66
+ from matplotlib import pyplot as plt
67
+ fig, ax = plt.subplots(1, 1, figsize=(15, 0.5))
68
+
69
+ bounds = np.linspace(0, nlabels, nlabels + 1)
70
+ norm = colors.BoundaryNorm(bounds, nlabels)
71
+
72
+ cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None,
73
+ boundaries=bounds, format='%1i', orientation=u'horizontal')
74
+
75
+ return randRGBcolors, random_colormap
76
+
annotator/lama/saicinpainting/training/visualizers/directory.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ from annotator.lama.saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch
7
+ from annotator.lama.saicinpainting.utils import check_and_warn_input_range
8
+
9
+
10
+ class DirectoryVisualizer(BaseVisualizer):
11
+ DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ')
12
+
13
+ def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10,
14
+ last_without_mask=True, rescale_keys=None):
15
+ self.outdir = outdir
16
+ os.makedirs(self.outdir, exist_ok=True)
17
+ self.key_order = key_order
18
+ self.max_items_in_batch = max_items_in_batch
19
+ self.last_without_mask = last_without_mask
20
+ self.rescale_keys = rescale_keys
21
+
22
+ def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
23
+ check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image')
24
+ vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch,
25
+ last_without_mask=self.last_without_mask,
26
+ rescale_keys=self.rescale_keys)
27
+
28
+ vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')
29
+
30
+ curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}')
31
+ os.makedirs(curoutdir, exist_ok=True)
32
+ rank_suffix = f'_r{rank}' if rank is not None else ''
33
+ out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg')
34
+
35
+ vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
36
+ cv2.imwrite(out_fname, vis_img)
annotator/lama/saicinpainting/training/visualizers/noop.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from annotator.lama.saicinpainting.training.visualizers.base import BaseVisualizer
2
+
3
+
4
+ class NoopVisualizer(BaseVisualizer):
5
+ def __init__(self, *args, **kwargs):
6
+ pass
7
+
8
+ def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
9
+ pass